grib2io.xarray_backend

grib2io Backend Engine for Xarray

grib2io provides an Xarray backend entrypoint for decoding GRIB2 messages from one or more files, representing them as Xarray DataArray objects that are collected along common coordinates into Datasets and DataTrees.

The grib2io.xarray_backend engine API is experimental. Its interface and behavior may change in future releases, which could affect backward compatibility.

Users are encouraged to treat this backend as subject to change and to pin their grib2io version if depending on its current implementation details.

   1"""
   2grib2io Backend Engine for Xarray
   3=================================
   4grib2io provides a Xarray backend entrypoint for decoding many GRIB2 messages
   5from a single file or many files and represented as Xarray DataArray objects and
   6collected along common coordinates as Datasets and DataTrees.
   7
   8.. warning::
   9
  10   The ``grib2io.xarray_backend`` engine API is **experimental**.
  11   Its interface and behavior may change in future releases,
  12   which could affect backward compatibility.
  13
  14   Users are encouraged to treat this backend as subject to change
  15   and to pin their ``grib2io`` version if depending on its current
  16   implementation details.
  17"""
  18from grib2io._grib2io import _data
  19from grib2io import Grib2Message, Grib2GridDef, msgs_from_index
  20import grib2io
  21from xarray.backends.locks import SerializableLock
  22from xarray.core import indexing
  23from xarray.backends import (
  24    BackendArray,
  25    BackendEntrypoint,
  26)
  27from copy import copy
  28from collections import defaultdict
  29from dataclasses import dataclass, field, astuple
  30import importlib.metadata
  31import itertools
  32import logging
  33import typing
  34import warnings
  35
  36from . import tables
  37
  38import numpy as np
  39import pandas as pd
  40import xarray as xr
  41import re
  42from pyproj import CRS
  43import datetime
  44
  45# Check if xarray version supports DataTree
  46_HAS_DATATREE = False
  47try:
  48    # Try importing DataTree to check if it's available
  49    xarray_version = importlib.metadata.version('xarray')
  50    xarray_parts = [int(x) if x.isdigit() else x for x in xarray_version.split('.')]
  51    min_version_parts = [2024, 10, 0]
  52    _HAS_DATATREE = xarray_parts >= min_version_parts
  53except (ImportError, ValueError):
  54    _HAS_DATATREE = False
  55
# Module-level logger for this backend.
_logger = logging.getLogger(__name__)

# Module-wide lock (xarray's SerializableLock). Presumably shared by the lazily
# loaded arrays so concurrent reads are serialized — TODO confirm at call sites.
_LOCK = SerializableLock()

# grib2io lookup table '4.5.grib2io.level.name' (fixed-surface/level naming;
# NOTE(review): exact key/value schema not visible here — confirm in grib2io.tables).
_LEVEL_NAME_MAPPING = grib2io.tables.get_table('4.5.grib2io.level.name')
  61
  62_TREE_HIERARCHY_LEVELS = [
  63    "typeOfFirstFixedSurface",
  64    "valueOfFirstFixedSurface",
  65    "productDefinitionTemplateNumber",
  66    "perturbationNumber",
  67    "leadTime",
  68    "duration",
  69    "percentileValue",
  70    "typeOfProbability",
  71    "thresholdLowerLimit",
  72    "thresholdUpperLimit"
  73]
  74
  75AVAILABLE_NON_GEO_COORDS = [
  76    "duration",
  77    "leadTime",
  78    "percentileValue",
  79    "perturbationNumber",
  80    "refDate",
  81    "thresholdLowerLimit",
  82    "thresholdUpperLimit",
  83    "valueOfFirstFixedSurface",
  84    "valueOfSecondFixedSurface",
  85    "aerosolType",
  86    "scaledValueOfFirstWavelength",
  87    "scaledValueOfSecondWavelength",
  88    "scaledValueOfCentralWaveNumber",
  89    "scaledValueOfFirstSize",
  90    "scaledValueOfSecondSize"
  91]
  92"""Available non-geographic coordinate names."""
  93
  94AVAILABLE_NON_GEO_DIMS = [
  95    "duration",
  96    "leadTime",
  97    "percentileValue",
  98    "perturbationNumber",
  99    "refDate",
 100    "threshold",
 101    "level",
 102]
 103"""Available non-geographic dimension names."""
 104
 105# Lookup table to define surface types that should be parsed as vertical coordinates
 106VERTICAL_COORDINATE_SURFACES = [
 107    "Ground or Water Surface",
 108    "Isothermal Level",
 109    "Specified radius from the centre of the Sun",
 110    "Isobaric Surface",
 111    "Mean Sea Level",
 112    "Specific Altitude Above Mean Sea Level",
 113    "Specified Height Level Above Ground",
 114    "Sigma Level",
 115    "Hybrid Level",
 116    "Depth Below Land Surface",
 117    "Isentropic (theta) Level",
 118    "Level at Specified Pressure Difference from Ground to Level",
 119    "Potential Vorticity Surface",
 120    "Eta Level",
 121    "Logarithmic Hybrid Level",
 122    "Sigma height level",
 123    "Hybrid Height Level",
 124    "Hybrid Pressure Level",
 125    "Soil level",
 126    "Sea-ice level",
 127    "Depth Below Sea Level",
 128    "Depth Below Water Surface",
 129    "Ocean Model Level",
 130    "Ocean level defined by water density (sigma-theta) difference from near-surface to level",
 131    "Ocean level defined by water potential temperature difference from near-surface to level",
 132    "Ocean level defined by vertical eddy diffusivity difference from near-surface to level",
 133    "Ocean level defined by water density (rho) difference from near-surface to level"
 134]
 135"""
 136Lookup table to define surface types that should be parsed as vertical coordinates
 137when `data_model="nws-viz"`.
 138"""
 139
 140def parse_data_model(ds, data_model):
 141    """
 142    Normalize a GRIB2-derived Dataset to a target data model (currently ``"nws-viz"``).
 143
 144    When ``data_model == "nws-viz"``, this function converts coordinate and
 145    variable names to snake_case, derives CF-like metadata, promotes select
 146    GRIB-derived quantities to coordinates, optionally swaps dimensions, and
 147    standardizes units/attributes. If ``data_model`` is anything else, the
 148    input dataset is returned unchanged.
 149
 150    Parameters
 151    ----------
 152    ds : xarray.Dataset
 153        GRIB2-derived dataset whose variables and attributes follow the
 154        conventions emitted by ``grib2io``. Expected to contain GRIB-related
 155        attributes such as ``typeOfFirstFixedSurface``,
 156        ``typeOfSecondFixedSurface``, and (for probabilistic variables)
 157        ``typeOfProbability``.
 158    data_model : str
 159        Target data model name. Only the value ``"nws-viz"`` triggers
 160        transformations.
 161
 162    Returns
 163    -------
 164    xarray.Dataset
 165        A new dataset with:
 166        * Selected coordinates renamed:
 167          ``refDate -> forecast_reference_time``,
 168          ``leadTime -> lead_time``,
 169          ``validDate -> time``,
 170          ``percentileValue -> percentile``,
 171          ``thresholdLowerLimit -> threshold_lower_limit``,
 172          ``thresholdUpperLimit -> threshold_upper_limit``.
 173        * Vertical coordinates derived from
 174          ``valueOfFirstFixedSurface`` / ``valueOfSecondFixedSurface`` and their
 175          corresponding ``typeOf*FixedSurface`` definitions. New coordinate
 176          names are generated from the surface definition (lowercased, spaces
 177          to underscores, punctuation removed). If the name already exists, a
 178          ``"_2"`` suffix is appended.
 179        * Possible dimension swaps:
 180          ``level -> <derived_vertical_coord>`` when present; and for
 181          probabilistic variables, ``threshold -> threshold_lower_limit`` or
 182          ``threshold -> threshold_upper_limit`` when
 183          ``typeOfProbability`` indicates the appropriate semantics.
 184        * Variable names lowercased; dataset- and variable-level attributes
 185          converted to snake_case (except GRIB section attributes which are
 186          normalized to ``grib...``).
 187        * CF-adjacent metadata populated: ``standard_name`` and
 188          ``cell_methods`` are set via the shortname→CF lookup table.
 189        * Percent units normalized from ``"%"`` to ``"percent"`` on coordinates.
 190        * For precipitation type (``PTYPE``) thresholds, numeric codes are
 191          decoded to strings (GRIB2 Table 4.201) in relevant attrs/coords.
 192
 193    Notes
 194    -----
 195    - Precipitation type decoding uses GRIB2 Table 4.201 via
 196      ``tables.get_value_from_table(code, "4.201")`` and returns a NumPy
 197      array with ``np.dtypes.StringDType``.
 198    - CF-related lookups are performed using
 199      ``tables.get_table("shortname_to_cf")``.
 200    - Vertical coordinate surface names are validated against
 201      ``VERTICAL_COORDINATE_SURFACES`` before promotion to coordinates.
 202
 203    Warnings
 204    --------
 205    This function assumes the presence of certain GRIB-derived attributes on the
 206    first data variable (e.g., ``typeOfFirstFixedSurface``,
 207    ``typeOfSecondFixedSurface``, and possibly ``typeOfProbability``).
 208    If these are absent or malformed, errors (e.g., ``KeyError``) may occur.
 209
 210    Examples
 211    --------
 212    >>> ds2 = parse_data_model(ds, "nws-viz")
 213    >>> list(ds2.coords)
 214    ['forecast_reference_time', 'lead_time', 'time', 'percentile', ...]
 215    """
 216
 217    def _decode_ptype(values):
 218        """
 219        Decode precipitation type values into human-readable strings.
 220
 221        Uses GRIB2 Table 4.201 to map numeric codes to precipitation type descriptions.
 222
 223        Parameters
 224        ----------
 225        values : array_like
 226            Array of numeric precipitation type codes (e.g., integers or floats).
 227            Each value corresponds to a GRIB2 Table 4.201 precipitation type code.
 228
 229        Returns
 230        -------
 231        numpy.ndarray
 232            Array of decoded precipitation type strings with
 233            NumPy’s flexible string data type (`np.dtypes.StringDType`).
 234        """
 235        results = []
 236        for val in values:
 237            # Convert each numeric code to string and look up in Table 4.201
 238            results.append(str(tables.get_value_from_table(str(int(val)), '4.201')))
 239
 240        # Return array of strings using numpy's string data type
 241        return np.array(results, dtype=np.dtypes.StringDType)
 242
 243    # convert coordinates and attributes to CF if requested
 244    if data_model == 'nws-viz':
 245
 246        # define regex to convert to snake case
 247        pattern = re.compile(r'(?<!^)(?=[A-Z])')
 248
 249        # check for coordinates and rename
 250        for coord in ds.coords:
 251            if coord == 'refDate':
 252                ds = ds.rename({'refDate': 'forecast_reference_time'})
 253
 254            elif coord == 'leadTime':
 255                ds = ds.rename({'leadTime': 'lead_time'})
 256
 257            elif coord == 'validDate':
 258                ds = ds.rename({'validDate': 'time'})
 259
 260            elif coord == 'percentileValue':
 261                ds = ds.rename({'percentileValue': 'percentile'})
 262
 263            elif coord == 'thresholdLowerLimit':
 264                ds = ds.rename({'thresholdLowerLimit': 'threshold_lower_limit'})
 265                ds['threshold_lower_limit'].attrs['long_name'] = 'Threshold Lower Limit'
 266                ds['threshold_lower_limit'].attrs['units'] = ds[list(ds.data_vars.keys())[0]].attrs['units']
 267
 268                if 'PTYPE' in ds.data_vars:
 269                    ds['threshold_lower_limit'] = xr.apply_ufunc(_decode_ptype, ds['threshold_lower_limit'])
 270
 271                # check if thresholdLowerLimit should be a dimension coordinate
 272                if 'threshold' in ds.dims:
 273                    var_key = list(ds.data_vars.keys())[0]
 274                    prob_types = [
 275                        'Probability of event below lower limit',
 276                        'Probability of event above lower limit',
 277                        'Probability of event equal to lower limit',
 278                        'Probability of event between upper and lower limits (the range includes lower limit but not the upper limit)'
 279                    ]
 280                    if ds[var_key].attrs['typeOfProbability'] in prob_types:
 281                        ds = ds.swap_dims({'threshold': 'threshold_lower_limit'})
 282
 283            elif coord == 'thresholdUpperLimit':
 284                ds = ds.rename({'thresholdUpperLimit': 'threshold_upper_limit'})
 285                ds['threshold_upper_limit'].attrs['long_name'] = 'Threshold Upper Limit'
 286                ds['threshold_upper_limit'].attrs['units'] = ds[list(ds.data_vars.keys())[0]].attrs['units']
 287
 288                if 'PTYPE' in ds.data_vars:
 289                    ds['threshold_upper_limit'] = xr.apply_ufunc(_decode_ptype, ds['threshold_upper_limit'])
 290
 291                if 'threshold' in ds.dims:
 292                    var_key = list(ds.data_vars.keys())[0]
 293                    prob_types = [
 294                        'Probability of event below upper limit',
 295                        'Probability of event above upper limit'
 296                    ]
 297                    if ds[var_key].attrs['typeOfProbability'] in prob_types:
 298                        ds = ds.swap_dims({'threshold': 'threshold_upper_limit'})
 299
 300            # If the dataset has valueOfFirstFixedSurface as a coordinate
 301            elif coord == 'valueOfFirstFixedSurface':
 302                # Get the valueOfFirstFixedSurface coordinate
 303                da = ds.valueOfFirstFixedSurface
 304
 305                # Get the definition and units from typeOfFirstFixedSurface
 306                var_key = list(ds.data_vars.keys())[0]
 307                definition, units = ds[var_key].attrs['typeOfFirstFixedSurface']
 308
 309                if definition in VERTICAL_COORDINATE_SURFACES:
 310                    # Convert definition to lowercase and replace spaces with underscores
 311                    key = definition.lower().replace(' ', '_')
 312
 313                    # remove special characters
 314                    key = re.sub(r'[^a-z0-9_]', '', key)
 315
 316                    # Add units and grib_name attributes
 317                    da.attrs['units'] = units
 318                    da.attrs['grib_name'] = ['valueOfFirstFixedSurface', 'typeOfFirstFixedSurface']
 319
 320                    # Assign the coordinate with the new key name
 321                    ds = ds.assign_coords({key: da})
 322
 323                    # If valueOfFirstFixedSurface is a dimension, swap it with the new key
 324                    if 'level' in ds.dims:
 325                        ds = ds.swap_dims({"level": key})
 326
 327                # Remove the original coordinates
 328                del ds['valueOfFirstFixedSurface']
 329
 330            # If the dataset has valueOfSecondFixedSurface as a coordinate
 331            elif coord == 'valueOfSecondFixedSurface':
 332                # Get the valueOfSecondFixedSurface coordinate
 333                da = ds.valueOfSecondFixedSurface
 334
 335                # Get the definition and units from typeOfSecondFixedSurface
 336                var_key = list(ds.data_vars.keys())[0]
 337                definition, units = ds[var_key].attrs['typeOfSecondFixedSurface']
 338
 339                if definition in VERTICAL_COORDINATE_SURFACES:
 340                    # Convert definition to lowercase and replace spaces with underscores
 341                    key = definition.lower().replace(' ', '_')
 342
 343                    # remove special characters
 344                    key = re.sub(r'[^a-z0-9_]', '', key)
 345
 346                    # check if key is already in coords
 347                    if key in ds.coords:
 348                        key = key + '_2'
 349
 350                    # Add units and grib_name attributes
 351                    da.attrs['units'] = units
 352                    da.attrs['grib_name'] = ['valueOfSecondFixedSurface', 'typeOfSecondFixedSurface']
 353
 354                    # Assign the coordinate with the new key name
 355                    ds = ds.assign_coords({key: da})
 356
 357                # Remove the original coordinates
 358                del ds['valueOfSecondFixedSurface']
 359            else:
 360                # change coord name to snake case
 361                new_coord_name = pattern.sub('_', coord).lower()
 362                ds = ds.rename({coord: new_coord_name})
 363
 364        # convert all attributes and variable names to snake case
 365        for var in ds.data_vars:
 366            da = ds[var]
 367            record = tables.get_table('shortname_to_cf').get(da.name)
 368            da.attrs['standard_name'] = 'unknown' if record is None else record['cf_standard_name']
 369            da.attrs['cell_methods'] = 'unknown' if record is None else record['cf_cell_methods']
 370
 371            ds[var] = da
 372
 373            # rename variable
 374            new_var_name = var.lower()
 375            ds = ds.rename({var: new_var_name})
 376
 377            # remove attr for typeOfFirstFixedSurface (applied as coordinate above)
 378            if 'typeOfFirstFixedSurface' in ds[new_var_name].attrs:
 379                definition, units = ds[new_var_name].attrs['typeOfFirstFixedSurface']
 380                ds[new_var_name].attrs['typeOfFirstFixedSurface'] = f'{definition} ({units})'
 381
 382            if 'typeOfSecondFixedSurface' in ds[new_var_name].attrs:
 383                definition, units = ds[new_var_name].attrs['typeOfSecondFixedSurface']
 384                ds[new_var_name].attrs['typeOfSecondFixedSurface'] = f'{definition} ({units})'
 385
 386            ds[new_var_name].attrs.pop('percentileValue', None)
 387
 388            if 'threshold_lower_limit' in ds.coords:
 389                ds[new_var_name].attrs.pop('thresholdLowerLimit', None)
 390
 391            if 'threshold_upper_limit' in ds.coords:
 392                ds[new_var_name].attrs.pop('thresholdUpperLimit', None)
 393
 394            for attr in list(ds[new_var_name].attrs.keys()):
 395                # skip grib section attrs
 396                if 'GRIB2IO_section' in attr:
 397                    # replace GRIB2IO with grib in attr
 398                    new_attr_name = attr.replace('GRIB2IO', 'grib')
 399                else:
 400                    # change attr name to snake case
 401                    new_attr_name = pattern.sub('_', attr).lower()
 402
 403                # update new attr name for specific CF names
 404                if new_attr_name == 'full_name':
 405                    new_attr_name = 'long_name'
 406
 407                # change % to percent
 408                if attr == 'units' and ds[new_var_name].attrs[attr] == '%':
 409                    ds[new_var_name].attrs[attr] = 'percent'
 410
 411                if new_var_name == 'ptype' and 'threshold' in new_attr_name:
 412                    value = ds[new_var_name].attrs.pop(attr)
 413                    ds[new_var_name].attrs[attr] = _decode_ptype(value)
 414                else:
 415                    # change attr name in attrs
 416                    ds[new_var_name].attrs[new_attr_name] = ds[new_var_name].attrs.pop(attr)
 417
 418
 419        # change dataset attrs to snake case
 420        for attr in list(ds.attrs.keys()):
 421            # change attr name to snake case
 422            new_attr_name = pattern.sub('_', attr).lower()
 423
 424            # change attr name in attrs
 425            ds.attrs[new_attr_name] = ds.attrs.pop(attr)
 426
 427        # change % to percent
 428        for coord in ds.coords:
 429            if 'units' in ds[coord].attrs and ds[coord].attrs['units'] == '%':
 430                ds[coord].attrs['units'] = 'percent'
 431
 432    return ds
 433
 434
 435class GribBackendEntrypoint(BackendEntrypoint):
 436    """
 437    xarray backend engine entrypoint for opening and decoding grib2 files.
 438
 439    .. warning::
 440
 441       This backend is experimental and the API/behavior may change without
 442       backward compatibility.
 443    """
 444
 445    def open_dataset(
 446        self,
 447        filename,
 448        *,
 449        drop_variables=None,
 450        save_index=True,
 451        filters: typing.Mapping[str, typing.Any] = dict(),
 452        data_model=None
 453    ):
 454        """
 455        Read and parse metadata from grib file.
 456
 457        Parameters
 458        ----------
 459        filename
 460            GRIB2 file to be opened.
 461        filters
 462            Filter GRIB2 messages to single hypercube. Dict keys can be any
 463            GRIB2 metadata attribute name.
 464        data_model
 465            Parse GRIB metadata following a defined data model comvention.
 466
 467        Returns
 468        -------
 469        open_dataset
 470            Xarray dataset of grib2 messages.
 471        """
 472        with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
 473            file_index = pd.DataFrame(f._index)
 474            file_index = file_index.assign(msg=msgs_from_index(f._index))
 475
 476        # parse grib2io _index to dataframe and acquire non-geo possible dims
 477        # (scalar coord when not dim due to squeeze) parse_grib_index applies
 478        # filters to index and expands metadata based on product definition
 479        # template number
 480        file_index, dim_coords, attrs, coord_attrs = parse_grib_index(file_index, filters)
 481
 482        # Divide up records by variable
 483        frames, cube, extra_geo = make_variables(file_index, filename, dim_coords)  # have this return var_attrs
 484
 485        # return empty dataset if no data
 486        if frames is None:
 487            return xr.Dataset()
 488
 489        # create dataframe and add datarrays without any coords
 490        ds = xr.Dataset()
 491        for var_df in frames:
 492            da = build_da_without_coords(var_df, cube, filename, attrs)
 493            ds[da.name] = da
 494
 495        # add coords and dataset meta
 496        ds = assign_xr_meta(ds, frames, cube, dim_coords, extra_geo, coord_attrs)
 497
 498        if data_model is not None:
 499            ds = parse_data_model(ds, data_model)
 500
 501        # assign attributes
 502        ds.attrs['engine'] = 'grib2io'
 503
 504        return ds
 505
 506    def open_datatree(
 507        self,
 508        filename,
 509        *,
 510        drop_variables=None,
 511        save_index=True,
 512        filters: typing.Mapping[str, typing.Any] = None,
 513        stack_vertical: bool = False,
 514    ):
 515        """
 516        Open a GRIB2 file as an xarray DataTree.
 517
 518        Parameters
 519        ----------
 520        filename : str
 521            Path to the GRIB2 file.
 522        drop_variables : list, optional
 523            List of variables to exclude.
 524        filters : dict, optional
 525            Filter criteria for GRIB2 messages.
 526        stack_vertical : bool, optional
 527            If True, organize the tree with vertical layers stacked in a single dataset.
 528
 529        Returns
 530        -------
 531        xarray.DataTree
 532            A hierarchical DataTree representation of the GRIB2 data.
 533        """
 534        if not _HAS_DATATREE:
 535            raise ImportError("xarray version does not support DataTree functionality.")
 536
 537        if filters is None:
 538            filters = {}
 539
 540        # Open the file without any filters first to get all messages
 541        with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
 542            file_index = pd.DataFrame(f._index)
 543            file_index = file_index.assign(msg=msgs_from_index(f._index))
 544
 545        # Build tree structure from GRIB messages with specified options
 546        tree = build_datatree_from_grib(filename, file_index, filters, stack_vertical=stack_vertical)
 547
 548        # Put warning here so it is the last message from likely other Xarray warnings.
 549        warnings.warn(
 550            "grib2io’s xarray backend DataTree support is experimental. "
 551            "The DataTree structure or attributes may change in future releases.",
 552        UserWarning,
 553        stacklevel=2,
 554        )
 555
 556        return tree
 557
 558
 559class GribBackendArray(BackendArray):
 560
 561    def __init__(self, array, lock):
 562        self.array = array
 563        self.shape = array.shape
 564        self.dtype = np.dtype(array.dtype)
 565        self.lock = lock
 566
 567    def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayLike:
 568        return xr.core.indexing.explicit_indexing_adapter(
 569            key,
 570            self.shape,
 571            indexing.IndexingSupport.BASIC,
 572            self._raw_getitem,
 573        )
 574
 575    def _raw_getitem(self, key: tuple):
 576        """Implement thread safe access to data on disk."""
 577        with self.lock:
 578            return self.array[key]
 579
 580
 581def exclusive_slice_to_inclusive(item: slice):
 582    """
 583    Convert a slice with exclusive stop to an inclusive slice.
 584
 585    If the slice has a step, the stop is reduced by the step, so that both
 586    interpretations would yield the same result.
 587
 588    The means that [start, stop) is converted to [start, stop - step].
 589
 590    Parameters
 591    ----------
 592    item
 593        The slice to convert.
 594
 595    Returns
 596    -------
 597    slice
 598        The converted slice.
 599    """
 600    # return the None slice
 601    if item.start is None and item.stop is None and item.step is None:
 602        return item
 603    if not isinstance(item, slice):
 604        raise ValueError(f'item must be a slice; it was of type {type(item)}')
 605    # if step is None, it's one
 606    step = 1 if item.step is None else item.step
 607    if item.stop < item.start or step < 1:
 608        raise ValueError(f'slice {item} not accounted for')
 609    # handle case where slice has one item
 610    if abs(item.stop - item.start) == step:
 611        return [item.start]
 612    # other cases require reducing the stop by the step
 613    s = slice(item.start, item.stop - step, step)
 614    return s
 615
 616
 617class Validator:
 618    def __set_name__(self, owner, name):
 619        self.private_name = f'_{name}'
 620        self.name = name
 621
 622    def __get__(self, obj, objtype=None):
 623        try:
 624            value = getattr(obj, self.private_name)
 625        except AttributeError:
 626            value = None
 627        return value
 628
 629
 630class PdIndex(Validator):
 631
 632    def __set__(self, obj, value):
 633        try:
 634            value = pd.Index(value)
 635        except TypeError:
 636            value = pd.Index([value])
 637        setattr(obj, self.private_name, value)
 638
 639
 640def _asarray_tuplesafe(values):
 641    """
 642    Convert values to a numpy array of at most 1-dimension and preserve tuples.
 643
 644    Adapted from pandas.core.common._asarray_tuplesafe
 645    """
 646    if isinstance(values, tuple):
 647        result = np.empty(1, dtype=object)
 648        result[0] = values
 649    else:
 650        result = np.asarray(values)
 651        if result.ndim == 2:
 652            result = np.empty(len(values), dtype=object)
 653            result[:] = values
 654
 655    return result
 656
 657
 658def array_safe_eq(a, b) -> bool:
 659    """Check if a and b are equal, even if they are numpy arrays."""
 660    if a is b:
 661        return True
 662    if hasattr(a, 'equals'):
 663        return a.equals(b)
 664    if hasattr(a, 'all') and hasattr(b, 'all'):
 665        return a.shape == b.shape and (a == b).all()
 666    if hasattr(a, 'all') or hasattr(b, 'all'):
 667        return False
 668    try:
 669        return a == b
 670    except TypeError:
 671        return NotImplementedError
 672
 673
 674def dc_eq(dc1, dc2) -> bool:
 675    """Check if two dataclasses which hold numpy arrays are equal."""
 676    if dc1 is dc2:
 677        return True
 678    if dc1.__class__ is not dc2.__class__:
 679        return NotImplementedError
 680    t1 = astuple(dc1)
 681    t2 = astuple(dc2)
 682    return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))
 683
 684
 685def coords_from_cube(cube) -> typing.Dict[str, xr.Variable]:
 686    keys = list(cube.keys())
 687    keys.remove('x')
 688    keys.remove('y')
 689    coords = dict()
 690    for k in keys:
 691        if k is not None:
 692            if len(cube[k]) > 1:
 693                coords[k] = xr.Variable(dims=k, data=cube[k], attrs=dict(grib_name=k))
 694            elif len(cube[k]) == 1:
 695                coords[k] = xr.Variable(dims=tuple(), data=cube[k][0], attrs=dict(grib_name=k))
 696    return coords
 697
 698
 699@dataclass
 700class OnDiskArray:
 701    file_name: str
 702    index: pd.DataFrame = field(repr=False)
 703    cube: dict = field(repr=False)
 704    shape: typing.Tuple[int, ...] = field(init=False)
 705    ndim: int = field(init=False)
 706    geo_ndim: int = field(init=False)
 707    dtype = 'float32'
 708
 709    def __post_init__(self):
 710        # multiple grids not allowed so can just use first
 711        geo_shape = (self.index.iloc[0].ny, self.index.iloc[0].nx)
 712
 713        self.geo_shape = geo_shape
 714        self.geo_ndim = len(geo_shape)
 715
 716        if len(self.index) == 1:
 717            self.shape = geo_shape
 718        else:
 719            if self.index.index.nlevels == 1:
 720                self.shape = tuple([len(self.index.index)]) + geo_shape
 721            else:
 722                self.shape = tuple([len(i) for i in self.index.index.levels]) + geo_shape
 723        self.ndim = len(self.shape)
 724
 725        cols = ['msg', 'sectionOffset']
 726        self.index = self.index[cols]
 727
 728    def __getitem__(self, item) -> np.array:
 729        # dimensions not in index are internal to tdlpack records; 2 dims for
 730        # grids; 1 dim for stations
 731
 732        index_slicer = item[:-self.geo_ndim]
 733        # maintain all multindex levels
 734        index_slicer = tuple([[i] if isinstance(i, int) else i for i in index_slicer])
 735
 736        # pandas loc slicing is inclusive, therefore convert slices into
 737        # explicit lists
 738        index_slicer_inclusive = tuple([exclusive_slice_to_inclusive(
 739            i) if isinstance(i, slice) else i for i in index_slicer])
 740
 741        # get records selected by item in new index dataframe
 742        if len(index_slicer_inclusive) == 1:
 743            index = self.index.loc[index_slicer_inclusive]
 744        elif len(index_slicer_inclusive) > 1:
 745            index = self.index.loc[index_slicer_inclusive, :]
 746        else:
 747            index = self.index
 748        index = index.set_index(index.index)
 749
 750        # set miloc to new relative locations in sub array
 751        index['miloc'] = list(
 752            zip(*[index.index.unique(level=dim).get_indexer(index.index.get_level_values(dim)) for dim in index.index.names]))
 753
 754        if len(index_slicer_inclusive) == 1:
 755            array_field_shape = tuple([len(index.index)]) + self.geo_shape
 756        elif len(index_slicer_inclusive) > 1:
 757            array_field_shape = index.index.levshape + self.geo_shape
 758        else:
 759            array_field_shape = self.geo_shape
 760
 761        array_field = np.full(array_field_shape, fill_value=np.nan, dtype="float32")
 762
 763        with open(self.file_name, mode='rb') as filehandle:
 764            for key, row in index.iterrows():
 765
 766                bitmap_offset = None if pd.isna(row['sectionOffset'][6]) else int(row['sectionOffset'][6])
 767                values = _data(filehandle, row.msg, bitmap_offset, row['sectionOffset'][7])
 768
 769                if len(index_slicer_inclusive) >= 1:
 770                    array_field[row.miloc] = values
 771                else:
 772                    array_field = values
 773
 774        # handle geo dim slicing
 775        array_field = array_field[(Ellipsis,) + item[-self.geo_ndim:]]
 776
 777        # squeeze array dimensions expressed as integer
 778        for i, it in reversed(list(enumerate(item[: -self.geo_ndim]))):
 779            if isinstance(it, int):
 780                array_field = array_field[(slice(None, None, None),) * i + (0,)]
 781
 782        return array_field
 783
 784
 785def dims_to_shape(d) -> tuple:
 786    if 'nx' in d:
 787        t = (d['ny'], d['nx'])
 788    else:
 789        t = (d['nsta'],)
 790    return t
 791
 792
 793def filter_index(index, k, v):
 794    if isinstance(v, slice):
 795        index = index.set_index(k)
 796        index = index.loc[v]
 797        index = index.reset_index()
 798    else:
 799        label = (
 800            v
 801            if getattr(v, "ndim", 1) > 1  # vectorized-indexing
 802            else _asarray_tuplesafe(v)
 803        )
 804        if label.ndim == 0:
 805            # see https://github.com/pydata/xarray/pull/4292 for details
 806            label_value = label[()] if label.dtype.kind in "mM" else label.item()
 807            try:
 808                indexer = pd.Index(index[k]).get_loc(label_value)
 809                if isinstance(indexer, int):
 810                    index = index.iloc[[indexer]]
 811                else:
 812                    index = index.iloc[indexer]
 813            except KeyError:
 814                index = index.iloc[[]]
 815        else:
 816            indexer = pd.Index(index[k]).get_indexer_for(np.ravel(v))
 817            index = index.iloc[indexer[indexer >= 0]]
 818
 819    return index
 820
 821
 822def parse_grib_index(
 823    index: pd.DataFrame,
 824    filters: typing.Mapping[str, typing.Any] = dict(),
 825):
 826    """
 827    Apply filters.
 828
 829    Evaluate remaining dimensions based on pdtn and parse each out.
 830
 831    Parameters
 832    ----------
 833    index
 834        Pandas DataFrame containing the GRIB2 message index.
 835    filters
 836        Filter GRIB2 messages to single hypercube. Dict keys can be any
 837        GRIB2 metadata attribute name.
 838
 839    Returns
 840    -------
 841    index
 842        Modified Pandas DataFrame with added GRIB2 metadata columns.
 843    dim_coords
 844        List of GRIB2 attributes that will be used for coordinates and/or dimensions.
 845    attrs
 846        Dict of metadata attributes (non-coordinates, non-geo)
 847    """
 848
 849    # make a copy of filters, remove filters as they are applied
 850    filters = copy(filters)
 851
 852    for k, v in filters.items():
 853        if k not in index.columns:
 854            kwarg = {k: index.msg.apply(lambda msg: getattr(msg, k))}
 855            index = index.assign(**kwarg)
 856        # adopt parts of xarray's sel logic  so that filters behave similarly
 857        # allowed to filter to nothing to make empty dataset
 858        index = filter_index(index, k, v)
 859
 860    if len(index) == 0:
 861        return index, list()
 862
 863    dim_coords = dict()  # key=name of dim, value=list of coord names
 864    attrs = dict()
 865    coord_attrs = dict()
 866
 867    # expand index
 868    index = index.assign(shortName=index.msg.apply(lambda msg: msg.shortName))
 869    index = index.assign(nx=index.msg.apply(lambda msg: msg.nx))
 870    index = index.assign(ny=index.msg.apply(lambda msg: msg.ny))
 871    index = index.astype({'ny': 'int', 'nx': 'int'})
 872
 873    # apply common filters(to all definition templates) to reduce dataset to
 874    # single cube
 875    # ensure only one of each of the below exists after filters applied
 876    required_uniques = [
 877        "productDefinitionTemplateNumber",
 878        "typeOfGeneratingProcess",
 879        "typeOfFirstFixedSurface",
 880        "typeOfSecondFixedSurface",
 881    ]
 882
 883    def meta_check(index, attrs, meta):
 884        """
 885        add meta to the datframe index
 886        check that there is a single type
 887        add the type to attrs
 888
 889        returns index, attrs
 890        """
 891        index = index.assign(**{meta: index.msg.apply(lambda msg: getattr(msg, meta))})
 892
 893        unique = index[meta].unique()
 894        if len(index[meta].unique()) > 1:
 895            raise ValueError(f'filter to a single {meta}; found: {[str(i) for i in unique]}')
 896        value = unique.item()
 897        if type(value) == grib2io.templates.Grib2Metadata:
 898            value = value.definition
 899
 900        # None is returned if no value found,
 901        # check and change to string None
 902        if value is None:
 903            value = 'None'
 904
 905        attrs[meta] = value
 906        return index, attrs
 907
 908    for meta in required_uniques:
 909        index, attrs = meta_check(index, attrs, meta)
 910
 911    pdtn = index.productDefinitionTemplateNumber.iloc[0].value
 912
 913    # determine which non geo dimensions can be created from data by this point
 914    # the index is filtered down to a single type for all required_uniques
 915
 916    # Dim Name     # matching dim_name for using this data as index coordinate
 917    dim_coords["refDate"] = ["refDate"]
 918    coord_attrs["refDate"] = dict(standard_name="forecast_reference_time")
 919#   dim_coords["refDate"] = ["refDate", "hour"] # non dim name matching items in list are used as non-index coordinates
 920
 921    dim_coords["leadTime"] = ["leadTime"]
 922    coord_attrs["leadTime"] = dict(standard_name="forecast_period")
 923
 924    if 'valueOfFirstFixedSurface' not in index.columns:
 925        index = index.assign(valueOfFirstFixedSurface=index.msg.apply(lambda msg: msg.valueOfFirstFixedSurface))
 926    if 'valueOfsecondFixedSurface' not in index.columns:
 927        index = index.assign(valueOfSecondFixedSurface=index.msg.apply(lambda msg: msg.valueOfSecondFixedSurface))
 928
 929    # dim name api change, user could run ds = ds.swap_dims(fixedSurface="valueOfFirstFixedSurface")
 930    index = index.assign(level=list(zip(index['valueOfFirstFixedSurface'], index['valueOfSecondFixedSurface'])))
 931#   index = index.assign(level=index.msg.apply(lambda msg: msg.level))
 932    # lack of "level" indeicates don't create extra index coordinate "level"
 933    dim_coords["level"] = ["valueOfFirstFixedSurface", "valueOfSecondFixedSurface"]
 934
 935    # logic for parsing possible dims from specific product definition section
 936
 937    if pdtn in {5, 9}:
 938
 939        # Probability forecasts at a horizontal level or in a horizontal layer
 940        # in a continuous or non-continuous time interval.  (see Template
 941        # 4.9)
 942        #       AVAILABLE_THRESHOLD = {
 943        #           0: {'has_lower': True, 'has_upper': False},
 944        #           1: {'has_lower': False, 'has_upper': True},
 945        #           2: {'has_lower': True, 'has_upper': True},
 946        #           3: {'has_lower': True, 'has_upper': False},
 947        #           4: {'has_lower': False, 'has_upper': True},
 948        #           5: {'has_lower': True, 'has_upper': False},
 949        #       }
 950
 951        index, attrs = meta_check(index, attrs, "typeOfProbability")
 952        if 'thresholdLowerLimit' not in index.columns:
 953            index = index.assign(thresholdLowerLimit=index.msg.apply(lambda msg: msg.thresholdLowerLimit))
 954        if 'thresholdUpperLimit' not in index.columns:
 955            index = index.assign(thresholdUpperLimit=index.msg.apply(lambda msg: msg.thresholdUpperLimit))
 956        if 'threshold' not in index.columns:
 957            # using composite of lower and upper, but could use threshold string from grib2io as long as that is unique and based on lower and upper
 958            index = index.assign(threshold=list(zip(index['thresholdLowerLimit'], index['thresholdUpperLimit'])))
 959#           index = index.assign(threshold = index.msg.apply(lambda msg: msg.threshold))
 960
 961        # ommiting threshold results in no index being assigned for this possible dim
 962        dim_coords["threshold"] = ["thresholdLowerLimit", "thresholdUpperLimit"]
 963
 964    if pdtn in {6, 10}:
 965
 966        # Percentile forecasts at a horizontal level or in a horizontal layer
 967        # in a continuous or non-continuous time interval.  (see Template
 968        # 4.10)
 969        dim_coords["percentileValue"] = ["percentileValue"]
 970        coord_attrs["percentileValue"] = dict(long_name='percentile', units='percent')
 971
 972    if pdtn in {8, 9, 10, 11, 12, 13, 14, 42, 43, 45, 46, 47, 61, 62, 63, 67, 68, 72, 73, 78, 79, 82, 83, 84, 85, 87, 91}:
 973        dim_coords["duration"] = ["duration"]
 974
 975    if pdtn in {1, 11, 33, 34, 41, 43, 45, 47, 49, 54, 56, 58, 59, 63, 68, 77, 79, 81, 83, 84, 85, 92}:
 976        dim_coords["perturbationNumber"] = ["perturbationNumber"]
 977
 978    if pdtn in {2,3,4,12,13,14}:
 979        index, attrs = meta_check(index, attrs, 'typeOfDerivedForecast')
 980
 981    if pdtn in {8,15,42,46,62,67,72,78,82,1001,1002,1100,1101}:
 982        index, attrs = meta_check(index, attrs, 'statisticalProcess')
 983
 984    # Finish logic by pdtn
 985
 986    for k, v in dim_coords.items():
 987        for meta in v:
 988            if meta not in index.columns:
 989                index = index.assign(**{meta: index.msg.apply(lambda msg: getattr(msg, meta))})
 990
 991    return index, dim_coords, attrs, coord_attrs
 992
 993
 994# Custom open_datatree function to open grib files as DataTree
 995def open_datatree(filename, *, filters: typing.Mapping[str, typing.Any] = None, engine="grib2io"):
 996    """
 997    Open a GRIB2 file as an xarray DataTree.
 998
 999    Parameters
1000    ----------
1001    filename : str
1002        Path to the GRIB2 file.
1003    filters : dict, optional
1004        Filter criteria for GRIB2 messages.
1005    engine : str, optional
1006        Engine to use for opening the file, defaults to "grib2io".
1007
1008    Returns
1009    -------
1010    xarray.DataTree
1011        A hierarchical DataTree representation of the GRIB2 data.
1012    """
1013    if not _HAS_DATATREE:
1014        raise ImportError("xarray version does not support DataTree functionality.")
1015
1016    if filters is None:
1017        filters = {}
1018
1019    # Open the file without any filters first to get all messages
1020    with grib2io.open(filename, _xarray_backend=True) as f:
1021        file_index = pd.DataFrame(f._index)
1022
1023    # Create a DataTree root
1024    tree = xr.DataTree()
1025
1026    # Build tree structure from GRIB messages
1027    return build_datatree_from_grib(filename, file_index, filters)
1028
1029
def build_da_without_coords(index, cube, filename, attrs) -> xr.DataArray:
    """
    Build a DataArray without coordinates from a cube of grib2 messages.

    Parameters
    ----------
    index
        Index of cube; one row per GRIB2 message with a ``msg`` column.
    cube
        Cube of grib2 messages: mapping of dimension/metadata name to its
        values. Entries holding more than one value become dimensions.
    filename
        Filename of grib2 file.
    attrs
        Extra attributes applied to ``DataArray.attrs`` last, so they take
        precedence over the attributes set here.

    Returns
    -------
    DataArray
        Lazily loaded DataArray without coordinates, carrying GRIB2 section
        arrays and plain-language metadata as attributes.

    Raises
    ------
    ValueError
        If the number of GRIB2 messages in ``index`` differs from the product
        of the non-geographic dimension lengths, or if the lazily indexed data
        has a different number of dimensions than ``cube`` implies.
    """

    dim_names = [k for k in cube.keys() if cube[k] is not None and len(cube[k]) > 1]
    constant_meta_names = [k for k in cube.keys() if cube[k] is None]
    dims = {k: len(cube[k]) for k in dim_names}

    # guard against bad DataArrays being formed: the non-geographic
    # dimensions must multiply out to exactly the number of GRIB2 messages
    dims_total = 1
    dims_to_filter = []
    for dim_name, dim_len in dims.items():
        if dim_name not in {'x', 'y', 'station'}:
            dims_total *= dim_len
            dims_to_filter.append(dim_name)

    # Check number of GRIB2 message indexed compared to non-X/Y
    # dimensions.
    if dims_total != len(index):
        raise ValueError(
            f"DataArray dimensions are not compatible with number of GRIB2 messages; DataArray has {dims_total} "
            f"and GRIB2 index has {len(index)}. Consider applying a filter for dimensions: {dims_to_filter}"
        )

    data = OnDiskArray(filename, index, cube)
    lock = _LOCK
    data = GribBackendArray(data, lock)
    data = indexing.LazilyIndexedArray(data)
    if len(dim_names) != len(data.shape):
        raise ValueError(
            "different number of dimensions on data "
            f"and dims: {len(data.shape)} vs {len(dim_names)}\n"
            "Grib2 messages could not be formed into a data cube; "
            "It's possible extra messages exist along a non-accounted for dimension based on PDTN\n"
            "It might be possible to get around this by applying a filter on the non-accounted for dimension"
        )
    da = xr.DataArray(data, dims=dim_names)

    da.encoding['original_shape'] = data.shape

    da.encoding['preferred_chunks'] = {'y': -1, 'x': -1}
    msg1 = index.msg.iloc[0]

    # plain language metadata is minimized
    # add grib section metadata
    da.attrs['GRIB2IO_section0'] = msg1.section0
    da.attrs['GRIB2IO_section1'] = msg1.section1
    da.attrs['GRIB2IO_section2'] = msg1.section2 if msg1.section2 else []
    da.attrs['GRIB2IO_section3'] = msg1.section3
    da.attrs['GRIB2IO_section4'] = msg1.section4
    da.attrs['GRIB2IO_section5'] = msg1.section5
    da.attrs['fullName'] = str(msg1.fullName)
    da.attrs['shortName'] = str(msg1.shortName)
    da.attrs['units'] = str(msg1.units)
    da.attrs['originatingCenter'] = str(msg1.originatingCenter.definition)
    da.attrs['originatingSubCenter'] = str(msg1.originatingSubCenter.definition)

    # add master table
    da.attrs['masterTableInfo'] = str(msg1.masterTableInfo.definition)

    da.name = index.shortName.iloc[0]
    for meta_name in constant_meta_names:
        if meta_name in index.columns:
            da.attrs[meta_name] = index[meta_name].iloc[0]

    da.attrs.update(attrs)

    return da
1115
1116
def assign_xr_meta(ds, frames, cube, non_geo_dims, extra_geo, coord_attrs):
    """
    Attach coordinates and metadata attributes to a Dataset of GRIB2 variables.

    Parameters
    ----------
    ds
        Dataset whose data variables were built without coordinates.
    frames
        Per-variable index DataFrames; only the first one is consulted since
        all variables share the same shape and index metadata.
    cube
        Cube of dimension values, turned into coordinates via
        ``coords_from_cube``.
    non_geo_dims
        Mapping of dimension name to the coordinate names derived for it.
        A dimension whose own name is not among its coordinate names has its
        index coordinate dropped.
    extra_geo
        Extra geographic coordinates to assign.
    coord_attrs
        Mapping of coordinate name to attributes added to that coordinate.

    Returns
    -------
    Dataset
        Dataset with coordinates, CRS/grid attributes, a ``validDate``
        coordinate (when refDate + leadTime can be computed) and
        ``attrs['engine'] = 'grib2io'``.
    """
    # coordinates from the cube; the cube prevents DataArrays with
    # different shapes
    ds = ds.assign_coords(coords_from_cube(cube))

    # first variable stands in for all; they share shape and index metadata
    first_frame = frames[0]

    # extra coordinates associated with each non-geo dimension
    for dim, names in non_geo_dims.items():
        keep_dim_coord = dim in names
        for coord_name in names:
            if coord_name == dim:
                continue
            if ds[dim].size == 1:
                # scalar coordinate
                scalar_value = first_frame[coord_name].unique().item()
                ds = ds.assign_coords({coord_name: [scalar_value]}).squeeze()
            else:
                # "ValueError: can only convert an array of size 1 to a Python
                # scalar" indicates the coord is not compatible with the index
                ix_level = first_frame.index.get_level_values(f'{dim}_ix')
                values = [
                    first_frame[ix_level == position][coord_name].unique().item()
                    for position in range(ds[dim].size)
                ]
                ds = ds.assign_coords({coord_name: pd.Index(values, name=dim)})
        if not keep_dim_coord:
            ds = ds.drop_vars(dim)

    # extra geographic coordinates
    ds = ds.assign_coords(extra_geo)

    # CRS / grid attributes from the first GRIB2 message go on every data
    # variable and on the Dataset itself
    first_msg = first_frame.msg.iloc[0]
    geo_attrs = {
        'crs_wkt': CRS.from_dict(first_msg.projParameters).to_wkt(),
        'gridlengthXDirection': first_msg.gridlengthXDirection,
        'gridlengthYDirection': first_msg.gridlengthYDirection,
        'latitudeFirstGridpoint': first_msg.latitudeFirstGridpoint,
        'longitudeFirstGridpoint': first_msg.longitudeFirstGridpoint,
    }
    for var_name in ds.data_vars:
        ds[var_name].attrs.update(geo_attrs)
    ds.attrs.update(geo_attrs)

    # coordinate specific attributes
    for coord_name, extra_attrs in coord_attrs.items():
        ds[coord_name].attrs.update(extra_attrs)

    # valid date = reference date + lead time; best effort, warn on failure
    try:
        ds = ds.assign_coords(dict(validDate=ds.coords['refDate']+ds.coords['leadTime']))
        ds.validDate.attrs['standard_name'] = 'time'
        ds.validDate.attrs['long_name'] = 'time'
    except Exception as e:
        warnings.warn(f'could not parse validTime: {e}')

    ds.attrs['engine'] = 'grib2io'

    return ds
1173
1174
def _cubes_equal(a, b) -> bool:
    """
    Return True when two cube dicts have the same keys and equal values.

    Cube values are either ``pandas.Index`` objects or single-element lists.
    A plain ``a != b`` dict comparison calls ``==`` on ``pandas.Index`` values,
    which returns an array and raises when its truth value is taken. Each value
    is therefore wrapped in ``pandas.Index`` and compared with ``Index.equals``.
    """
    if a.keys() != b.keys():
        return False
    return all(pd.Index(a[key]).equals(pd.Index(b[key])) for key in a)


def make_variables(index, f, non_geo_dims, allow_uneven_dims=False):
    """
    Create an individual dataframe index and cube for each variable.

    Parameters
    ----------
    index
        Index of cube; needs ``shortName``, ``ny``, ``nx`` and ``msg``
        columns plus one column per key of ``non_geo_dims``.
    f
        Not referenced by this function; kept for signature compatibility.
    non_geo_dims
        Dimensions not associated with the x,y grid.
    allow_uneven_dims
        If True, allows uneven dimensions and differing cubes between
        variables (used for DataTree creation).

    Returns
    -------
    ordered_frames
        List of dataframes, one for each variable (sorted by shortName).
    cube
        Cube of grib2 messages.
    extra_geo
        Extra geographic coordinates (latitude/longitude DataArrays).

    All three are None when ``index`` is empty.

    Raises
    ------
    ValueError
        If a dimension has an uneven number of messages per value or the
        variables do not share the same cube (unless ``allow_uneven_dims``),
        or if the messages do not all share one grid size.
    """
    # let shortName determine the variables
    index = index.set_index('shortName').sort_index()
    # return nothing if no data
    if index.empty:
        return None, None, None

    ordered_meta = list(non_geo_dims.keys())
    cube = None
    ordered_frames = list()
    for key in index.index.unique():
        # frame holds all records for one variable
        frame = index.loc[[key]].reset_index()
        c = dict()
        for colname in ordered_meta:
            uniques = pd.Index(frame[colname]).unique()
            if len(uniques) > 1:
                c[colname] = uniques.sort_values()
            else:
                c[colname] = [uniques[0]]

        dims = [k for k in ordered_meta if len(c[k]) > 1]

        for dim in dims:
            if frame[dim].value_counts().nunique() > 1 and not allow_uneven_dims:
                raise ValueError(
                    f'uneven number of grib msgs associated with dimension: {dim}\n unique values for {dim}: {frame[dim].unique()} ')

        if len(dims) >= 1:  # dims may be empty if no extra dims on top of x,y
            frame = frame.sort_values(dims)
            frame = frame.set_index(dims)

        if cube:
            if not allow_uneven_dims and not _cubes_equal(cube, c):
                raise ValueError(f'{cube},\n {c};\n cubes are not the same; filter to a single cube')
        else:
            cube = c

        # miloc is multi-index integer location of msg in nd DataArray
        miloc = list(zip(*[frame.index.unique(level=dim).get_indexer(frame.index.get_level_values(dim))
                           for dim in dims]))

        # set frame multi index
        if len(miloc) >= 1:  # miloc will be empty when no extra dims, thus no multiindex
            dim_ix = tuple([n + '_ix' for n in dims])
            frame = frame.set_index(pd.MultiIndex.from_tuples(miloc, names=dim_ix))

        ordered_frames.append(frame)

    # no variables
    if cube is None:
        cube = dict()

    # check geography of data and assign to cube
    if len(index.ny.unique()) > 1 or len(index.nx.unique()) > 1:
        raise ValueError('multiple grids not accommodated')
    cube["y"] = range(int(index.ny.iloc[0]))
    cube["x"] = range(int(index.nx.iloc[0]))

    # we want the lat lons; make them via accessing a record; we are assuming
    # all records are the same grid because they have the same shape;
    # may want a unique grid identifier from grib2io to avoid assuming this
    msg = index.msg.iloc[0]
    latitude, longitude = msg.latlons()
    latitude = xr.DataArray(latitude, dims=['y', 'x'])
    latitude.attrs['standard_name'] = 'latitude'
    latitude.attrs['units'] = 'degrees_north'
    longitude = xr.DataArray(longitude, dims=['y', 'x'])
    longitude.attrs['standard_name'] = 'longitude'
    longitude.attrs['units'] = 'degrees_east'
    extra_geo = dict(latitude=latitude, longitude=longitude)

    return ordered_frames, cube, extra_geo
1280
1281
def interp_nd(a, *, method, grid_def_in, grid_def_out, method_options=None, num_threads=1):
    """
    Interpolate an N-d array whose last two axes form the grid.

    Leading axes are collapsed into one so ``grib2io.interpolate`` sees a
    3-D stack, then restored on the result, whose last two axes take the
    shape returned by the interpolation.
    """
    leading_shape = a.shape[:-2]
    ny, nx = a.shape[-2], a.shape[-1]
    stacked = a.reshape(-1, ny, nx)
    result = grib2io.interpolate(stacked, method, grid_def_in, grid_def_out,
                                 method_options=method_options,
                                 num_threads=num_threads)
    out_ny, out_nx = result.shape[-2], result.shape[-1]
    return result.reshape(leading_shape + (out_ny, out_nx))
1289
1290
def interp_nd_stations(a, *, method, grid_def_in, lats, lons, method_options=None, num_threads=1):
    """
    Interpolate an N-d array whose last two axes form the grid to station points.

    Leading axes are collapsed into one for
    ``grib2io.interpolate_to_stations``; the result is reshaped to the
    leading axes followed by one axis of length ``len(lats)``.
    """
    leading_shape = a.shape[:-2]
    ny, nx = a.shape[-2], a.shape[-1]
    stacked = a.reshape(-1, ny, nx)
    result = grib2io.interpolate_to_stations(stacked, method, grid_def_in, lats, lons,
                                             method_options=method_options,
                                             num_threads=num_threads)
    n_stations = len(lats)
    return result.reshape(leading_shape + (n_stations,))
1298
1299
@xr.register_dataset_accessor("grib2io")
class Grib2ioDataSet:
    """``Dataset.grib2io`` accessor: grid definition, interpolation, subsetting and GRIB2 output."""

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def _first_var_section3(self):
        """Return the ``GRIB2IO_section3`` attribute of the first data variable."""
        first_name = list(self._obj.data_vars)[0]
        return self._obj[first_name].attrs['GRIB2IO_section3']

    def griddef(self):
        """Return the Grib2GridDef described by the first data variable's section 3."""
        return Grib2GridDef.from_section3(self._first_var_section3())

    def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.Dataset:
        """
        Interpolate every data variable to ``grid_def_out``.

        The variables are stacked into one DataArray along ``variable``,
        interpolated with the DataArray accessor's ``interp`` and split back
        into a Dataset.
        """
        stacked = self._obj.to_array()
        stacked.attrs['GRIB2IO_section3'] = self._first_var_section3()
        regridded = stacked.grib2io.interp(method, grid_def_out,
                                           method_options=method_options,
                                           num_threads=num_threads)
        return regridded.to_dataset(dim='variable')

    def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.Dataset:
        """
        Interpolate every data variable to station points.

        The variables are stacked into one DataArray along ``variable``,
        interpolated with the DataArray accessor's ``interp_to_stations`` and
        split back into a Dataset.
        """
        stacked = self._obj.to_array()
        stacked.attrs['GRIB2IO_section3'] = self._first_var_section3()
        at_stations = stacked.grib2io.interp_to_stations(method, calls, lats, lons,
                                                         method_options=method_options,
                                                         num_threads=num_threads)
        return at_stations.to_dataset(dim='variable')

    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
        """
        Write a DataSet to a grib2 file.

        Variables are written in sorted name order. ``mode`` applies to the
        first variable only; every later variable is appended.

        Parameters
        ----------
        filename
            Name of the grib2 file to write to.
        mode: {"x", "w", "a"}, optional, default="x"
            Persistence mode

            | mode | Description                       |
            | :---:| :---:                             |
            | 'x'  | create (fail if exists)           |
            | 'w'  | create (overwrite if exists)      |
            | 'a'  | append (create if does not exist) |

        """
        ds = self._obj
        write_mode = mode
        for shortName in sorted(ds):
            ds[shortName].grib2io.to_grib2(filename, mode=write_mode)
            write_mode = "a"

    def update_attrs(self, **kwargs):
        """
        Always raise, because attributes are updated per DataArray.

        Parameters
        ----------
        **kwargs
            Attributes that were meant to be updated; echoed in the error.

        Raises
        ------
        ValueError
            Always.
        """
        raise ValueError(
            f"Datasets do not have a .attrs attribute; use .grib2io.update_attrs({kwargs}) on a DataArray instead."
        )

    def subset(self, lats, lons) -> xr.Dataset:
        """
        Subset the DataSet to a region defined by latitudes and longitudes.

        Parameters
        ----------
        lats
            Latitude bounds of the region.
        lons
            Longitude bounds of the region.

        Returns
        -------
        subset
            New DataSet whose data variables are copies of each variable's
            DataArray ``grib2io.subset``; Dataset-level attrs are not carried over.
        """
        source = self._obj

        result = xr.Dataset()
        for name in source:
            result[name] = source[name].grib2io.subset(lats, lons).copy()

        return result
1390
1391
1392@xr.register_dataarray_accessor("grib2io")
1393class Grib2ioDataArray:
1394
1395    def __init__(self, xarray_obj):
1396        self._obj = xarray_obj
1397
1398    def griddef(self):
1399        return Grib2GridDef.from_section3(self._obj.attrs['GRIB2IO_section3'])
1400
1401    def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.DataArray:
1402        """
1403        Perform grid spatial interpolation.
1404
1405        Uses the [NCEPLIBS-ip library](https://github.com/NOAA-EMC/NCEPLIBS-ip).
1406
1407        Parameters
1408        ----------
1409        method
1410            Interpolate method to use. This can either be an integer or string
1411            using the following mapping:
1412
1413            | Interpolate Scheme | Integer Value |
1414            | :---:              | :---:         |
1415            | 'bilinear'         | 0             |
1416            | 'bicubic'          | 1             |
1417            | 'neighbor'         | 2             |
1418            | 'budget'           | 3             |
1419            | 'spectral'         | 4             |
1420            | 'neighbor-budget'  | 6             |
1421        grid_def_out
1422            Grib2GridDef object of the output grid.
1423        method_options : list of ints, optional
1424            Interpolation options. See the NCEPLIBS-ip documentation for
1425            more information on how these are used.
1426        num_threads : int, optional
1427            Number of OpenMP threads to use for interpolation. The default
1428            value is 1. If grib2io_interp was not built with OpenMP, then
1429            this keyword argument and value will have no impact.
1430
1431        Returns
1432        -------
1433        interp
1434            DataSet interpolated to new grid definition.  The attribute
1435            GRIB2IO_section3 is replaced with the section3 array from the new
1436            grid definition.
1437        """
1438        da = self._obj
1439        # ensure that y, x are rightmost dims; they should be if opening with
1440        # grib2io engine
1441
1442        # gdtn and gdt is not the entirety of the new s3
1443        npoints = grid_def_out.npoints
1444        s3_new = np.array([0, npoints, 0, 0, grid_def_out.gdtn] + list(grid_def_out.gdt))
1445
1446        # make new lat lons
1447        lats, lons = Grib2Message(section3=s3_new, pdtn=0, drtn=0).grid()
1448        latitude = xr.DataArray(lats, dims=['y', 'x'])
1449        longitude = xr.DataArray(lons, dims=['y', 'x'])
1450
1451        # create new coords
1452        new_coords = dict(da.coords)
1453        del new_coords['latitude']
1454        del new_coords['longitude']
1455        new_coords['longitude'] = longitude
1456        new_coords['latitude'] = latitude
1457
1458        # make grid def in from section3 on da.attrs
1459        grid_def_in = self.griddef()
1460
1461        if da.chunks is None:
1462            data = interp_nd(da.data, method=method, grid_def_in=grid_def_in,
1463                             grid_def_out=grid_def_out,
1464                             method_options=method_options, num_threads=num_threads)
1465        else:
1466            import dask
1467            front_shape = da.shape[:-2]
1468            data = da.data.map_blocks(interp_nd, method=method, grid_def_in=grid_def_in,
1469                                      grid_def_out=grid_def_out, method_options=method_options,
1470                                      chunks=da.chunks[:-2]+latitude.shape, dtype=da.dtype)
1471
1472        new_da = xr.DataArray(data, dims=da.dims, coords=new_coords, attrs=da.attrs)
1473
1474        new_da.attrs['GRIB2IO_section3'] = s3_new
1475        new_da.name = da.name
1476        return new_da
1477
1478    def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.DataArray:
1479        """
1480        Perform spatial interpolation to station points.
1481
1482        Parameters
1483        ----------
1484        method
1485            Interpolate method to use. This can either be an integer or string
1486            using the following mapping:
1487
1488            | Interpolate Scheme | Integer Value |
1489            | :---:              | :---:         |
1490            | 'bilinear'         | 0             |
1491            | 'bicubic'          | 1             |
1492            | 'neighbor'         | 2             |
1493            | 'budget'           | 3             |
1494            | 'spectral'         | 4             |
1495            | 'neighbor-budget'  | 6             |
1496
1497        calls
1498            Station calls used for labeling new station index coordinate
1499        lats
1500            Latitudes of the station points.
1501        lons
1502            Longitudes of the station points.
1503
1504        Returns
1505        -------
1506        interp_to_stations
1507            DataArray interpolated to lat and lon locations and labeled with
1508            dimension and coordinate 'station'. (..., y, x) -> (..., station)
1509        """
1510        da = self._obj
1511        # TODO ensure that y, x are rightmost dims; they should be if opening
1512        # with grib2io engine
1513
1514        calls = np.asarray(calls)
1515        lats = np.asarray(lats)
1516        lons = np.asarray(lons)
1517        latitude = xr.DataArray(lats, dims=['station'])
1518        longitude = xr.DataArray(lons, dims=['station'])
1519
1520        # create new coords
1521        new_coords = dict(da.coords)
1522        del new_coords['latitude']
1523        del new_coords['longitude']
1524        new_coords['longitude'] = longitude
1525        new_coords['latitude'] = latitude
1526        new_coords['station'] = calls
1527
1528        new_dims = da.dims[:-2] + ('station',)
1529
1530        # make grid def in from section3 on da attrs
1531        grid_def_in = self.griddef()
1532
1533        if da.chunks is None:
1534            data = interp_nd_stations(da.data, method=method, grid_def_in=grid_def_in, lats=lats,
1535                                      lons=lons, method_options=method_options, num_threads=num_threads)
1536        else:
1537            import dask
1538            front_shape = da.shape[:-1]
1539            data = da.data.map_blocks(interp_nd_stations, method=method, grid_def_in=grid_def_in,
1540                                      lats=lats, lons=lons, method_options=method_options,
1541                                      drop_axis=-1, chunks=da.chunks[:-2]+latitude.shape,
1542                                      dtype=da.dtype)
1543
1544        new_da = xr.DataArray(data, dims=new_dims, coords=new_coords, attrs=da.attrs)
1545
1546        new_da.name = da.name
1547        return new_da
1548
1549    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
1550        """
1551        Write a DataArray to a grib2 file.
1552
1553        Parameters
1554        ----------
1555        filename
1556            Name of the grib2 file to write to.
1557        mode: {"x", "w", "a"}, optional, default="x"
1558            Persistence mode
1559
1560            +------+-----------------------------------+
1561            | mode | Description                       |
1562            +======+===================================+
1563            | x    | create (fail if exists)           |
1564            +------+-----------------------------------+
1565            | w    | create (overwrite if exists)      |
1566            +------+-----------------------------------+
1567            | a    | append (create if does not exist) |
1568            +------+-----------------------------------+
1569
1570        """
1571        da = self._obj.copy(deep=True)
1572
1573        coords_keys = sorted(da.coords.keys())
1574        coords_keys = [k for k in coords_keys if k in AVAILABLE_NON_GEO_COORDS]
1575
1576        # If there are dimension coordinates, the DataArray is a hypercube of
1577        # grib2 messages.
1578
1579        # Create `indexes` which is a list of lists of dictionaries for all
1580        # dimension coordinates. Each dictionary key is the dimension
1581        # coordinate name and the value is a list of the dimension coordinate
1582        # values.  This allows for easy iteration over all possible grib2
1583        # messages in the DataArray by using itertools.product.
1584        #
1585        # For example:
1586        # indexes = [
1587        #     [
1588        #         {"leadTime": 9},
1589        #         {"leadTime": 12},
1590        #     ],
1591        #     [
1592        #         {"valueOfFirstFixedSurface": 900},
1593        #         {"valueOfFirstFixedSurface": 925},
1594        #         {"valueOfFirstFixedSurface": 950},
1595        #     ],
1596        # ]
1597
1598        # assign loc indexes to dimensions without indexes for uniform selection by name
1599        loc_indexes = list()
1600        for dim in da.dims:
1601            if dim not in da.indexes:
1602                da = da.assign_coords({dim: range(da[dim].size)})
1603                loc_indexes.append(dim)
1604
1605        indexes = []
1606        for index in [i for i in AVAILABLE_NON_GEO_DIMS if i in da.dims]:
1607            values = da.coords[index].values
1608            if len(values) != len(set(values)):
1609                raise ValueError(
1610                    f"Dimension coordinate '{index}' has duplicate values, but to_grib2 requires unique values to find each GRIB2 message in the DataArray."
1611                )
1612            listeach = [{index: value} for value in sorted(values)]
1613            indexes.append(listeach)
1614
1615        # If `dim_coords` is [], then the DataArray is a single grib2 message and
1616        # itertools.product(*dim_coords) will run once with `selectors = ()`.
1617        for selectors in itertools.product(*indexes):
1618            # Need to find the correct data in the DataArray based on the
1619            # dimension coordinates.
1620            filters = {k: v for d in selectors for k, v in d.items()}
1621
1622            # If `filters` is {}, then the DataArray is a single grib2 message
1623            # and da.sel(indexers={}) returns the DataArray.
1624            selected = da.sel(indexers=filters)
1625
1626            newmsg = Grib2Message(
1627                selected.attrs["GRIB2IO_section0"],
1628                selected.attrs["GRIB2IO_section1"],
1629                selected.attrs["GRIB2IO_section2"],
1630                selected.attrs["GRIB2IO_section3"],
1631                selected.attrs["GRIB2IO_section4"],
1632                selected.attrs["GRIB2IO_section5"],
1633            )
1634            newmsg.data = np.array(selected.data)
1635
1636            # For dimension coordinates, set the grib2 message metadata to the
1637            # dimension coordinate value.
1638            for index, value in filters.items():
1639                if index not in loc_indexes:
1640                    setattr(newmsg, index, value)
1641
1642            # For non-dimension coordinates, set the grib2 message metadata to
1643            # the DataArray coordinate value.
1644            for index in [i for i in coords_keys if i not in da.dims]:
1645                setattr(newmsg, index, selected.coords[index].values)
1646
1647            # Set section 5 attributes to the da.encoding dictionary.
1648            for key, value in selected.encoding.items():
1649                if key in ["dtype", "chunks", "original_shape"]:
1650                    continue
1651                setattr(newmsg, key, value)
1652
1653            # write the message to file
1654            with grib2io.open(filename, mode=mode) as f:
1655                f.write(newmsg)
1656            mode = "a"
1657
1658    def update_attrs(self, **kwargs):
1659        """
1660        Update many of the attributes of the DataArray.
1661
1662        Parameters
1663        ----------
1664        **kwargs
1665            Attributes to update.  This can include many of the GRIB2IO message
1666            attributes that you can find when you print a GRIB2IO message. For
1667            conflicting updates, the last keyword will be used.
1668
1669            +-----------------------+------------------------------------------+
1670            | kwargs                | Description                              |
1671            +=======================+==========================================+
1672            | shortName="VTMP"      | Set shortName to "VTMP", along with      |
1673            |                       | appropriate discipline,                  |
1674            |                       | parameterCategory, parameterNumber,      |
1675            |                       | fullName and units.                      |
1676            +-----------------------+------------------------------------------+
1677            | discipline=0,         | Set shortName, discipline,               |
1678            | parameterCategory=0,  | parameterCategory, parameterNumber,      |
1679            | parameterNumber=1     | fullName and units appropriate for       |
1680            |                       | "Virtual Temperature".                   |
1681            +-----------------------+------------------------------------------+
1682            | discipline=0,         | Conflicting keywords but                 |
1683            | parameterCategory=0,  | 'shortName="TMP"' wins.  Set shortName,  |
1684            | parameterNumber=1,    | discipline, parameterCategory,           |
1685            | shortName="TMP"       | parameterNumber, fullName and units      |
1686            |                       | appropriate for "Temperature".           |
1687            +-----------------------+------------------------------------------+
1688
1689        Returns
1690        -------
1691        DataArray
1692            DataArray with updated attributes.
1693        """
1694        da = self._obj.copy(deep=True)
1695
1696        newmsg = Grib2Message(
1697            da.attrs["GRIB2IO_section0"],
1698            da.attrs["GRIB2IO_section1"],
1699            da.attrs["GRIB2IO_section2"],
1700            da.attrs["GRIB2IO_section3"],
1701            da.attrs["GRIB2IO_section4"],
1702            da.attrs["GRIB2IO_section5"],
1703        )
1704
1705        coords_keys = [
1706            k
1707            for k in da.coords.keys()
1708            if k in AVAILABLE_NON_GEO_COORDS
1709        ]
1710
1711        for grib2_name, value in kwargs.items():
1712            if grib2_name == "gridDefinitionTemplateNumber":
1713                raise ValueError(
1714                    "The gridDefinitionTemplateNumber attribute cannot be updated.  The best way to change to a different grid is to interpolate the data to a new grid using the grib2io interpolate functions."
1715                )
1716            if grib2_name == "productDefinitionTemplateNumber":
1717                raise ValueError(
1718                    "The productDefinitionTemplateNumber attribute cannot be updated."
1719                )
1720            if grib2_name == "dataRepresentationTemplateNumber":
1721                raise ValueError(
1722                    "The dataRepresentationTemplateNumber attribute cannot be updated."
1723                )
1724            if grib2_name in coords_keys:
1725                warnings.warn(
1726                    f"Skipping attribute '{grib2_name}' because it is a coordinate. Use da.assign_coords() to change coordinate values."
1727                )
1728                continue
1729            if hasattr(newmsg, grib2_name):
1730                setattr(newmsg, grib2_name, value)
1731            else:
1732                warnings.warn(
1733                    f"Skipping attribute '{grib2_name}' because it is not a valid GRIB2 attribute for this message and cannot be updated."
1734                )
1735                continue
1736
1737        da.attrs["GRIB2IO_section0"] = newmsg.section0
1738        da.attrs["GRIB2IO_section1"] = newmsg.section1
1739        da.attrs["GRIB2IO_section2"] = newmsg.section2 or []
1740        da.attrs["GRIB2IO_section3"] = newmsg.section3
1741        da.attrs["GRIB2IO_section4"] = newmsg.section4
1742        da.attrs["GRIB2IO_section5"] = newmsg.section5
1743        da.attrs["fullName"] = newmsg.fullName
1744        da.attrs["shortName"] = newmsg.shortName
1745        da.attrs["units"] = newmsg.units
1746
1747        return da
1748
1749    def subset(self, lats, lons) -> xr.DataArray:
1750        """
1751        Subset the DataArray to a region defined by latitudes and longitudes.
1752
1753        Parameters
1754        ----------
1755        lats
1756            Latitude bounds of the region.
1757        lons
1758            Longitude bounds of the region.
1759
1760        Returns
1761        -------
1762        subset
1763            DataArray subset to the region.
1764        """
1765        da = self._obj.copy(deep=True)
1766
1767        newmsg = Grib2Message(
1768            da.attrs["GRIB2IO_section0"],
1769            da.attrs["GRIB2IO_section1"],
1770            da.attrs["GRIB2IO_section2"],
1771            da.attrs["GRIB2IO_section3"],
1772            da.attrs["GRIB2IO_section4"],
1773            da.attrs["GRIB2IO_section5"],
1774        )
1775
1776        newmsg.data = np.zeros((newmsg.ny, newmsg.nx), dtype=np.float32)
1777
1778        newmsg = newmsg.subset(lats, lons)
1779
1780        da.attrs["GRIB2IO_section3"] = newmsg.section3
1781
1782        mask_lat = (da.latitude >= newmsg.latitudeLastGridpoint) & (
1783            da.latitude <= newmsg.latitudeFirstGridpoint
1784        )
1785        mask_lon = (da.longitude >= newmsg.longitudeFirstGridpoint) & (
1786            da.longitude <= newmsg.longitudeLastGridpoint
1787        )
1788
1789        del newmsg
1790
1791        return da.where((mask_lon & mask_lat).compute(), drop=True)
1792
1793
def build_datatree_from_grib(filename, file_index, filters=None, stack_vertical=False):
    """
    Build a DataTree from GRIB2 messages.

    Messages are grouped by ``typeOfFirstFixedSurface``; each level type
    becomes a child of the root named from ``_LEVEL_NAME_MAPPING`` (or
    ``level_<type>`` for unmapped types) and is populated by
    ``process_level_branch``.

    Parameters
    ----------
    filename : str
        Path to the GRIB2 file.
    file_index : pandas.DataFrame
        DataFrame of GRIB2 message index; must have a ``msg`` column holding
        the messages, from which missing metadata columns are derived.
    filters : dict, optional
        Filter criteria for GRIB2 messages.
    stack_vertical : bool, optional
        Not currently used by this function and kept for API compatibility.
        Multiple ``valueOfFirstFixedSurface`` levels of a variable are
        concatenated along that dimension by ``create_datasets_from_df``.

    Returns
    -------
    xarray.DataTree
        A hierarchical DataTree representation of the GRIB2 data.
    """
    if filters is None:
        filters = {}

    # Apply any filters from user, deriving the filter column from the
    # messages when the index does not already have it.
    for k, v in filters.items():
        if k not in file_index.columns:
            file_index = file_index.copy()
            file_index[k] = file_index.msg.apply(lambda msg, k=k: getattr(msg, k, None))
        file_index = filter_index(file_index, k, v)

    # Make a copy to avoid the SettingWithCopyWarning
    file_index = file_index.copy()

    def safe_getattr(obj, name):
        """Return ``obj.name`` (unwrapping Grib2Metadata to its value), or None if missing."""
        try:
            attr = getattr(obj, name)
            if isinstance(attr, grib2io.templates.Grib2Metadata):
                attr = attr.value
            return attr
        except (AttributeError, KeyError):
            return None

    # Extract metadata needed for tree organization.
    for attr in _TREE_HIERARCHY_LEVELS:
        if (attr not in file_index.columns) and (attr != 'valueOfFirstFixedSurface'):
            file_index[attr] = file_index.msg.apply(lambda msg, attr=attr: safe_getattr(msg, attr))

    # Also extract shortName (used for variable naming) and grid dimensions.
    if 'shortName' not in file_index.columns:
        file_index = file_index.assign(shortName=file_index.msg.apply(lambda msg: getattr(msg, 'shortName', None)))
        file_index = file_index.assign(nx=file_index.msg.apply(lambda msg: getattr(msg, 'nx', None)))
        file_index = file_index.assign(ny=file_index.msg.apply(lambda msg: getattr(msg, 'ny', None)))

    root = xr.DataTree()

    # One branch per (non-null) level type.
    for level_type in file_index['typeOfFirstFixedSurface'].unique():
        if pd.isna(level_type):
            continue
        # Mapping values are sequences whose first item is the node name.
        # Unmapped types fall back to "level_<type>"; previously the string
        # fallback was indexed with [0], naming every unmapped type "l".
        level_info = _LEVEL_NAME_MAPPING.get(level_type)
        level_name = level_info[0] if level_info is not None else f"level_{level_type}"

        level_df = file_index[file_index['typeOfFirstFixedSurface'] == level_type]

        level_tree = xr.DataTree()
        process_level_branch(level_tree, level_df, filename)
        root[level_name] = level_tree

    return root
1887
1888
def _column_has_multiple_values(df, column):
    """Return True when ``column`` exists in ``df`` with more than one distinct non-null value."""
    return column in df.columns and len(df[column].dropna().unique()) > 1


def _attach_datasets_to_tree(tree, dss):
    """Store a single Dataset as ``tree``'s data, or several as ``var_<name>`` children."""
    if len(dss) == 1:
        tree.ds = dss[0]
    else:
        for ds in dss:
            varname = list(ds.data_vars)[0]
            tree[f"var_{varname}"] = ds


def process_level_branch(level_tree, df, filename):
    """
    Process a level type branch of the data tree, organizing by PDTN and other attributes.

    With a single PDTN, perturbation/probability groups are added directly
    to ``level_tree``; otherwise datasets go into a ``pdtn_<N>`` child.  With
    several PDTNs, each gets its own ``pdtn_<N>`` child.  Within a PDTN,
    messages are split by ``perturbationNumber`` if it has several values,
    else by ``typeOfProbability`` if it has several values, else turned into
    datasets directly (falling back to one node per variable on error).

    Parameters
    ----------
    level_tree : xarray.DataTree
        The DataTree node for this level type
    df : pandas.DataFrame
        DataFrame of messages for this level type
    filename : str
        Path to the GRIB2 file
    """
    # Group data by PDTN first, skipping null PDTN values.
    pdtn_groups = {}
    for pdtn_value in df['productDefinitionTemplateNumber'].unique():
        if pd.notna(pdtn_value):
            pdtn_groups[pdtn_value] = df[df['productDefinitionTemplateNumber'] == pdtn_value]

    if len(pdtn_groups) == 1:
        # Single PDTN: perturbation/probability groups hang off the level node.
        pdtn, pdtn_df = next(iter(pdtn_groups.items()))
        pdtn_name = f"pdtn_{int(pdtn)}"

        if _column_has_multiple_values(pdtn_df, 'perturbationNumber'):
            process_perturbation_groups(level_tree, pdtn_df, filename)
        elif _column_has_multiple_values(pdtn_df, 'typeOfProbability'):
            process_probability_groups(level_tree, pdtn_df, filename)
        else:
            try:
                dss = create_datasets_from_df(pdtn_df, filename)
                if dss is not None:
                    dt = xr.DataTree()
                    _attach_datasets_to_tree(dt, dss)
                    level_tree[pdtn_name] = dt
            except Exception as e:
                print(f"Error creating dataset for level with pdtn {int(pdtn)}: {e}")

                # Try to separate by variable name as a fallback
                try_process_by_variables(level_tree, pdtn_df, filename)
        return

    # Multiple PDTN values: one pdtn_<N> branch per group.
    for pdtn, pdtn_df in pdtn_groups.items():
        pdtn_name = f"pdtn_{int(pdtn)}"
        pdtn_tree = xr.DataTree()

        if _column_has_multiple_values(pdtn_df, 'perturbationNumber'):
            process_perturbation_groups(pdtn_tree, pdtn_df, filename)
            # Only add the branch if something was added to it.  (The old
            # ``pdtn_tree.ds is not None`` test was always True because
            # DataTree.ds always returns a view.)
            if len(pdtn_tree.children) > 0 or pdtn_tree.has_data:
                level_tree[pdtn_name] = pdtn_tree
        elif _column_has_multiple_values(pdtn_df, 'typeOfProbability'):
            process_probability_groups(pdtn_tree, pdtn_df, filename)
            if len(pdtn_tree.children) > 0 or pdtn_tree.has_data:
                level_tree[pdtn_name] = pdtn_tree
        else:
            try:
                dss = create_datasets_from_df(pdtn_df, filename)
                if dss is not None:
                    _attach_datasets_to_tree(pdtn_tree, dss)
                    level_tree[pdtn_name] = pdtn_tree
            except Exception as e:
                print(f"Error creating dataset for level with pdtn {int(pdtn)}: {e}")

                # Try to separate by variable name as a fallback
                try_process_by_variables(pdtn_tree, pdtn_df, filename)
                level_tree[pdtn_name] = pdtn_tree
2004
2005
def process_probability_groups(target_tree, pdtn_df, filename):
    """
    Process type-of-probability groups and add them to the target tree.

    Messages are grouped by each non-null ``typeOfProbability`` value and
    every group that yields datasets becomes a child of ``target_tree``
    named ``prob_<typeOfProbability>``.  A single Dataset is stored as that
    node's data; several are stored as ``var_<name>`` children named after
    each Dataset's first data variable.

    Parameters
    ----------
    target_tree : xarray.DataTree
        The tree node to add probability groups to
    pdtn_df : pandas.DataFrame
        DataFrame of messages for a specific PDTN; must have a
        ``typeOfProbability`` column
    filename : str
        Path to the GRIB2 file

    Returns
    -------
    bool
        True if at least one probability group was added to ``target_tree``
    """
    success = False
    for prob_value in pdtn_df['typeOfProbability'].unique():
        if pd.isna(prob_value):
            continue
        prob_df = pdtn_df[pdtn_df['typeOfProbability'] == prob_value]
        prob_name = f"prob_{int(prob_value)}"

        try:
            dss = create_datasets_from_df(prob_df, filename)
            if not dss:
                # None (creation failed) or an empty list: nothing to attach.
                print(f"Error creating dataset for type of probability {prob_name}: no datasets were created")
                continue
            dt = xr.DataTree()
            if len(dss) == 1:
                dt.ds = dss[0]
            else:
                for ds in dss:
                    # data_vars is keyed by variable name, so take the first
                    # name explicitly (``ds.data_vars[0]`` raises KeyError).
                    varname = list(ds.data_vars)[0]
                    dt[f"var_{varname}"] = ds
            target_tree[prob_name] = dt
            success = True
        except Exception as e:
            # Log error but continue processing other groups
            print(f"Error creating dataset for type of probability {prob_name}: {e}")

    return success
2038
2039
def process_perturbation_groups(target_tree, pdtn_df, filename):
    """
    Process perturbation groups and add them to the target tree.

    Messages are grouped by each non-null ``perturbationNumber`` value and
    every group that yields datasets becomes a child of ``target_tree``
    named ``pert_<perturbationNumber>``.  A single Dataset is stored as that
    node's data; several are stored as ``var_<name>`` children named after
    each Dataset's first data variable.

    Parameters
    ----------
    target_tree : xarray.DataTree
        The tree node to add perturbation groups to
    pdtn_df : pandas.DataFrame
        DataFrame of messages for a specific PDTN; must have a
        ``perturbationNumber`` column
    filename : str
        Path to the GRIB2 file

    Returns
    -------
    bool
        True if at least one perturbation was successfully processed
    """
    success = False
    for pert_value in pdtn_df['perturbationNumber'].unique():
        if pd.isna(pert_value):
            continue
        pert_df = pdtn_df[pdtn_df['perturbationNumber'] == pert_value]
        pert_name = f"pert_{int(pert_value)}"

        try:
            dss = create_datasets_from_df(pert_df, filename)
            if not dss:
                # None (creation failed) or an empty list: nothing to attach.
                print(f"Error creating dataset for perturbation {pert_name}: no datasets were created")
                continue
            dt = xr.DataTree()
            if len(dss) == 1:
                dt.ds = dss[0]
            else:
                for ds in dss:
                    # data_vars is keyed by variable name, so take the first
                    # name explicitly (``ds.data_vars[0]`` raises KeyError).
                    varname = list(ds.data_vars)[0]
                    dt[f"var_{varname}"] = ds
            target_tree[pert_name] = dt
            success = True
        except Exception as e:
            # Log error but continue processing other groups
            print(f"Error creating dataset for perturbation {pert_name}: {e}")

    return success
2101
2102
def try_process_by_variables(target_tree, df, filename):
    """
    Fallback that builds one child node per variable (``shortName``).

    For every non-null ``shortName`` in ``df`` the matching messages are
    passed to ``create_datasets_from_df`` and the first resulting Dataset is
    stored on ``target_tree`` as ``var_<shortName>``.  Errors are printed
    and never raised: a failing variable is skipped, and a failure outside
    the per-variable step (e.g. a missing ``shortName`` column) ends
    processing.

    Parameters
    ----------
    target_tree : xarray.DataTree
        The tree node to add variable datasets to
    df : pandas.DataFrame
        DataFrame of messages
    filename : str
        Path to the GRIB2 file

    Returns
    -------
    bool
        True if at least one variable was successfully processed
    """
    added_any = False

    try:
        for name in df['shortName'].unique():
            if pd.isna(name):
                continue
            subset_df = df[df['shortName'] == name]
            try:
                built = create_datasets_from_df(subset_df, filename)
                if built is not None:
                    target_tree[f"var_{name}"] = built[0]
                    added_any = True
            except Exception as var_e:
                print(f"Error creating dataset for variable {name}: {var_e}")
    except Exception as nested_e:
        print(f"Failed to process variables: {nested_e}")

    return added_any
2138
2139
2140def create_datasets_from_df(
2141    df,
2142    filename,
2143    verbose=False
2144) -> typing.Optional[typing.List[xr.Dataset]]:
2145    """
2146    Create a list of xarray Datasets from a DataFrame of messages.
2147
2148    Parameters
2149    ----------
2150    df : pandas.DataFrame
2151        DataFrame of GRIB messages
2152    filename : str
2153        Path to the GRIB2 file
2154    verbose : bool, optional
2155        If True, prints detailed debugging information
2156
2157    Returns
2158    -------
2159    dss
2160        List of Datasets, or None if creation failed
2161    """
2162    try:
2163        if verbose:
2164            print(f"\n==== VERBOSE DEBUG INFO ====")
2165            print(f"Creating dataset from DataFrame with {len(df)} messages")
2166            print(f"DataFrame columns: {df.columns.tolist()}")
2167
2168            if 'shortName' in df.columns:
2169                print(f"Variables in group: {df['shortName'].unique().tolist()}")
2170
2171            if 'valueOfFirstFixedSurface' in df.columns:
2172                print(f"Vertical levels: {df['valueOfFirstFixedSurface'].unique().tolist()}")
2173
2174        # Process by variables
2175        datasets = {}
2176
2177        # Process each variable separately, regardless of whether there are vertical levels
2178        for var_name, var_df in df.groupby('shortName'):
2179            if verbose:
2180                print(
2181                    f"\n  Processing variable: {var_name} with {len(var_df)} messages, with pdtn(s) = {var_df['productDefinitionTemplateNumber'].unique()}")
2182
2183            # Process vertical levels if present
2184            if 'valueOfFirstFixedSurface' in var_df.columns and len(var_df['valueOfFirstFixedSurface'].unique()) > 1:
2185                if verbose:
2186                    print(f"  Variable {var_name} has multiple vertical levels")
2187                # Process each level separately
2188                level_das = []
2189
2190                for level, level_df in var_df.groupby('valueOfFirstFixedSurface'):
2191                    if verbose:
2192                        print(f"    Processing level {level} with {len(level_df)} messages")
2193                    try:
2194                        # Parse the index and get dimensions for this level
2195                        file_index, non_geo_dims, attrs, coord_attrs = parse_grib_index(level_df, {})
2196                        # Remove valueOfFirstFixedSurface from dimensions since we're handling it separately
2197                        non_geo_dims = [d for d in non_geo_dims if d.__name__ != "ValueOfFirstFixedSurfaceDim"]
2198
2199                        frames, cube, extra_geo = make_variables(
2200                            file_index, filename, non_geo_dims, allow_uneven_dims=True)
2201
2202                        if frames is not None and len(frames) == 1:
2203                            level_da = build_da_without_coords(frames[0], cube, filename, attrs)
2204                            # Add this level to the list with its level value as coord
2205                            level_da = level_da.assign_coords(valueOfFirstFixedSurface=level)
2206                            level_das.append(level_da)
2207                    except Exception as e:
2208                        if verbose:
2209                            print(f"    Error processing level {level} for {var_name}: {e}")
2210
2211                if level_das:
2212                    # Combine all levels into a single DataArray along the valueOfFirstFixedSurface dimension
2213                    if verbose:
2214                        print(f"    Combining {len(level_das)} levels for {var_name}")
2215                    try:
2216                        combined_da = xr.concat(level_das, dim='valueOfFirstFixedSurface')
2217                        # Create a simple dataset with just this variable
2218                        var_ds = xr.Dataset({var_name: combined_da})
2219                        # Assign the coords from the first level's cube
2220                        var_ds = assign_xr_meta(var_ds, frames, cube, non_geo_dims, extra_geo, coord_attrs)
2221                       # TODO: is the below code all now in assign_xr_meta? was there instances where refDate and leadTime were not coords?
2222                       # var_ds = var_ds.assign_coords(coords_from_cube(cube))
2223                       # Add extra geo coords
2224                       # if extra_geo:
2225                       #    var_ds = var_ds.assign_coords(extra_geo)
2226                       # Add valid date coords if available
2227                       # if 'refDate' in var_ds.coords and 'leadTime' in var_ds.coords:
2228                       #    var_ds = var_ds.assign_coords(dict(validDate=var_ds.coords['refDate']+var_ds.coords['leadTime']))
2229
2230                        # Store this variable's dataset
2231                        datasets[var_name] = var_ds
2232                        if verbose:
2233                            print(f"    Created dataset for {var_name} with levels")
2234                    except Exception as e:
2235                        if verbose:
2236                            print(f"    Error combining levels for {var_name}: {e}")
2237            else:
2238                # Single level or no vertical levels
2239                if verbose:
2240                    print(f"  Variable {var_name} is a single level or has no vertical dimension")
2241                try:
2242                    # Parse the index and get dimensions
2243                    file_index, non_geo_dims, attrs, coord_attrs = parse_grib_index(var_df, {})
2244                    frames, cube, extra_geo = make_variables(file_index, filename, non_geo_dims, allow_uneven_dims=True)
2245
2246                    if frames is not None and len(frames) == 1:
2247                        # Create dataset with this variable
2248                        var_ds = xr.Dataset()
2249                        da = build_da_without_coords(frames[0], cube, filename, attrs)
2250                        var_ds[da.name] = da
2251
2252                        # Assign coords
2253                        var_ds = assign_xr_meta(var_ds, frames, cube, non_geo_dims, extra_geo, coord_attrs)
2254                       # TODO: is the below code all now in assign_xr_meta? was there instances where refDate and leadTime were not coords?
2255                       # var_ds = var_ds.assign_coords(coords_from_cube(cube))
2256                       # if extra_geo:
2257                       #    var_ds = var_ds.assign_coords(extra_geo)
2258                       # if 'refDate' in var_ds.coords and 'leadTime' in var_ds.coords:
2259                       #    var_ds = var_ds.assign_coords(dict(validDate=var_ds.coords['refDate']+var_ds.coords['leadTime']))
2260
2261                        # Store this variable's dataset
2262                        datasets[var_name] = var_ds
2263                        if verbose:
2264                            print(f"  Created dataset for {var_name}")
2265                    elif frames is not None and len(frames) > 1:
2266                        if verbose:
2267                            print(f"  Variable {var_name} has multiple frames, possibly different parameters")
2268                        # Just use the first frame for now (simplified approach)
2269                        var_ds = xr.Dataset()
2270                        da = build_da_without_coords(frames[0], cube, filename, attrs)
2271                        var_ds[da.name] = da
2272
2273                        # Assign coords
2274                        var_ds = assign_xr_meta(var_ds, frames, cube, non_geo_dims, extra_geo, coord_attrs)
2275                       # TODO: is the below code all now in assign_xr_meta? was there instances where refDate and leadTime were not coords?
2276                       # var_ds = var_ds.assign_coords(coords_from_cube(cube))
2277                       # if extra_geo:
2278                       #    var_ds = var_ds.assign_coords(extra_geo)
2279                       # if 'refDate' in var_ds.coords and 'leadTime' in var_ds.coords:
2280                       #    var_ds = var_ds.assign_coords(dict(validDate=var_ds.coords['refDate']+var_ds.coords['leadTime']))
2281
2282                        datasets[var_name] = var_ds
2283                        if verbose:
2284                            print(f"  Created dataset with first frame for {var_name}")
2285                except Exception as e:
2286                    if verbose:
2287                        print(f"  Error processing variable {var_name}: {e}")
2288
2289        # Attempt to merge all the variable datasets
2290        if datasets:
2291            try:
2292                if verbose:
2293                    print(f"\nMerging {len(datasets)} datasets...")
2294                # Get the list of datasets to merge
2295                ds_list = list(datasets.values())
2296
2297                # Try merging them all at once
2298                try:
2299                    combined_ds = xr.merge(ds_list)
2300                    if verbose:
2301                        print(f"Successfully merged all datasets into one.")
2302                        print(f"Final dataset has variables: {list(combined_ds.data_vars)}")
2303                        print(f"==== END VERBOSE DEBUG INFO ====\n")
2304                    return [combined_ds]
2305                except Exception as merge_error:
2306                    if verbose:
2307                        print(f"Error merging all datasets: {merge_error}")
2308                    return ds_list
2309            except Exception as e:
2310                if verbose:
2311                    print(f"Error in final merge process: {e}")
2312                    print(f"==== END VERBOSE DEBUG INFO ====\n")
2313                return None
2314        else:
2315            if verbose:
2316                print(f"No datasets were created for any variables")
2317                print(f"==== END VERBOSE DEBUG INFO ====\n")
2318            return None
2319
2320    except Exception as e:
2321        # If there's an error, log it and return None
2322        if verbose:
2323            print(f"Error creating dataset: {e}")
2324            import traceback
2325            traceback.print_exc()
2326            print(f"==== END VERBOSE DEBUG INFO ====\n")
2327        return None
2328
2329
# Only register the DataTree accessor if DataTree is supported
if _HAS_DATATREE:
    @xr.register_datatree_accessor("grib2io")
    class Grib2ioDataTree:
        """
        DataTree accessor for GRIB2 files.

        This accessor provides methods for working with GRIB2 data organized
        in a hierarchical tree structure. Operations are delegated node by
        node to the ``grib2io`` Dataset accessor.
        """

        def __init__(self, datatree_obj):
            self._obj = datatree_obj

        @staticmethod
        def _has_data(node):
            """Return True when ``node`` carries at least one data variable."""
            return node.ds is not None and bool(node.ds.data_vars)

        def _map_data_nodes(self, func):
            """
            Build a new DataTree with the same structure as the wrapped tree.

            ``func`` is called with the dataset of every node that has data
            variables and its result becomes that node's dataset in the new
            tree; nodes without data variables become empty nodes.

            Parameters
            ----------
            func : callable
                Function taking the node's Dataset and returning a Dataset.

            Returns
            -------
            xarray.DataTree
                Newly built tree (the wrapped tree is not modified).
            """
            new_tree = xr.DataTree()

            def process_tree(node, new_node):
                if self._has_data(node):
                    new_node.ds = func(node.ds)

                for child_name, child_node in node.children.items():
                    # DataTree.__setitem__ attaches a shallow *copy* of the
                    # assigned node, so re-fetch the attached child before
                    # recursing; populating the local object instead would
                    # leave the results on a detached node.
                    new_node[child_name] = xr.DataTree()
                    process_tree(child_node, new_node.children[child_name])

            process_tree(self._obj, new_tree)
            return new_tree

        def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
            """
            Write all datasets in the DataTree to a GRIB2 file.

            Nodes are visited depth-first, parent before children; each node
            with data variables is written via its Dataset ``grib2io``
            accessor.

            Parameters
            ----------
            filename : str
                Name of the GRIB2 file to write to.
            mode : {"x", "w", "a"}, optional
                Persistence mode for the first write, default is "x"
                (create, fail if exists). Subsequent writes append.
            """
            current_mode = mode

            def process_tree(node):
                nonlocal current_mode

                if self._has_data(node):
                    node.ds.grib2io.to_grib2(filename, mode=current_mode)
                    # Every write after the first appends to the same file.
                    current_mode = "a"

                for child_node in node.children.values():
                    process_tree(child_node)

            process_tree(self._obj)

        def griddef(self):
            """
            Get the grid definition from the first dataset in the tree that has one.

            The tree is searched depth-first (node before its children) for a
            data variable carrying a ``GRIB2IO_section3`` attribute.

            Returns
            -------
            grib2io.Grib2GridDef or None
                Grid definition object, or None when no variable in the tree
                has a ``GRIB2IO_section3`` attribute.
            """
            def find_griddef(node):
                if self._has_data(node):
                    for var_name in node.ds.data_vars:
                        if 'GRIB2IO_section3' in node.ds[var_name].attrs:
                            return Grib2GridDef.from_section3(node.ds[var_name].attrs['GRIB2IO_section3'])

                for child_node in node.children.values():
                    griddef = find_griddef(child_node)
                    if griddef is not None:
                        return griddef

                return None

            return find_griddef(self._obj)

        def interp(self, method, grid_def_out, method_options=None, num_threads=1):
            """
            Interpolate all datasets in the tree to a new grid.

            Parameters
            ----------
            method : str or int
                Interpolation method to use
            grid_def_out : grib2io.Grib2GridDef
                Target grid definition
            method_options : list, optional
                Options for interpolation method
            num_threads : int, optional
                Number of threads to use for interpolation

            Returns
            -------
            xarray.DataTree
                New DataTree with the same structure and interpolated data
            """
            return self._map_data_nodes(
                lambda ds: ds.grib2io.interp(method, grid_def_out,
                                             method_options=method_options,
                                             num_threads=num_threads)
            )

        def subset(self, lats, lons):
            """
            Subset all datasets in the tree to a region.

            Parameters
            ----------
            lats : list or tuple
                Latitude bounds [min_lat, max_lat]
            lons : list or tuple
                Longitude bounds [min_lon, max_lon]

            Returns
            -------
            xarray.DataTree
                New DataTree with the same structure and subset data
            """
            return self._map_data_nodes(lambda ds: ds.grib2io.subset(lats, lons))
# Available non-geographic coordinate names. For data_model="nws-viz",
# parse_data_model renames several of these (e.g. refDate, thresholdLowerLimit).
AVAILABLE_NON_GEO_COORDS = ['duration', 'leadTime', 'percentileValue', 'perturbationNumber', 'refDate', 'thresholdLowerLimit', 'thresholdUpperLimit', 'valueOfFirstFixedSurface', 'valueOfSecondFixedSurface', 'aerosolType', 'scaledValueOfFirstWavelength', 'scaledValueOfSecondWavelength', 'scaledValueOfCentralWaveNumber', 'scaledValueOfFirstSize', 'scaledValueOfSecondSize']

Available non-geographic coordinate names.

# Available non-geographic dimension names. For data_model="nws-viz",
# parse_data_model may swap 'level' and 'threshold' for coordinates derived from
# the fixed-surface / threshold-limit values.
AVAILABLE_NON_GEO_DIMS = ['duration', 'leadTime', 'percentileValue', 'perturbationNumber', 'refDate', 'threshold', 'level']

Available non-geographic dimension names.

# Fixed-surface definitions (first element of typeOf*FixedSurface) that
# parse_data_model promotes to named vertical coordinates when data_model="nws-viz".
VERTICAL_COORDINATE_SURFACES = ['Ground or Water Surface', 'Isothermal Level', 'Specified radius from the centre of the Sun', 'Isobaric Surface', 'Mean Sea Level', 'Specific Altitude Above Mean Sea Level', 'Specified Height Level Above Ground', 'Sigma Level', 'Hybrid Level', 'Depth Below Land Surface', 'Isentropic (theta) Level', 'Level at Specified Pressure Difference from Ground to Level', 'Potential Vorticity Surface', 'Eta Level', 'Logarithmic Hybrid Level', 'Sigma height level', 'Hybrid Height Level', 'Hybrid Pressure Level', 'Soil level', 'Sea-ice level', 'Depth Below Sea Level', 'Depth Below Water Surface', 'Ocean Model Level', 'Ocean level defined by water density (sigma-theta) difference from near-surface to level', 'Ocean level defined by water potential temperature difference from near-surface to level', 'Ocean level defined by vertical eddy diffusivity difference from near-surface to level', 'Ocean level defined by water density (rho) difference from near-surface to level']

Lookup table of surface-type definitions that are promoted to vertical coordinates when data_model="nws-viz".

def parse_data_model(ds, data_model):
def parse_data_model(ds, data_model):
    """
    Normalize a GRIB2-derived Dataset to a target data model (currently ``"nws-viz"``).

    When ``data_model == "nws-viz"``, this function converts coordinate and
    variable names to snake_case, derives CF-like metadata, promotes select
    GRIB-derived quantities to coordinates, optionally swaps dimensions, and
    standardizes units/attributes. If ``data_model`` is anything else, the
    input dataset is returned unchanged.

    Parameters
    ----------
    ds : xarray.Dataset
        GRIB2-derived dataset whose variables and attributes follow the
        conventions emitted by ``grib2io``. Expected to contain GRIB-related
        attributes such as ``typeOfFirstFixedSurface``,
        ``typeOfSecondFixedSurface``, and (for probabilistic variables)
        ``typeOfProbability``.
    data_model : str
        Target data model name. Only the value ``"nws-viz"`` triggers
        transformations.

    Returns
    -------
    xarray.Dataset
        A new dataset with:
        * Selected coordinates renamed:
          ``refDate -> forecast_reference_time``,
          ``leadTime -> lead_time``,
          ``validDate -> time``,
          ``percentileValue -> percentile``,
          ``thresholdLowerLimit -> threshold_lower_limit``,
          ``thresholdUpperLimit -> threshold_upper_limit``.
        * Vertical coordinates derived from
          ``valueOfFirstFixedSurface`` / ``valueOfSecondFixedSurface`` and their
          corresponding ``typeOf*FixedSurface`` definitions. New coordinate
          names are generated from the surface definition (lowercased, spaces
          to underscores, punctuation removed). If the second surface's name
          already exists, a ``"_2"`` suffix is appended.
        * Possible dimension swaps:
          ``level -> <derived_vertical_coord>`` when present; and for
          probabilistic variables, ``threshold -> threshold_lower_limit`` or
          ``threshold -> threshold_upper_limit`` when
          ``typeOfProbability`` indicates the appropriate semantics.
        * Variable names lowercased; dataset- and variable-level attributes
          converted to snake_case (except GRIB section attributes which are
          normalized to ``grib...``).
        * CF-adjacent metadata populated: ``standard_name`` and
          ``cell_methods`` are set via the shortname→CF lookup table.
        * Percent units normalized from ``"%"`` to ``"percent"`` on
          coordinates and data variables.
        * For precipitation type (``PTYPE``) thresholds, numeric codes are
          decoded to strings (GRIB2 Table 4.201) in relevant attrs/coords.
          Scalar thresholds are supported; decoded scalar attributes are
          stored as plain strings.

    Notes
    -----
    - The input dataset is not modified in place: for ``"nws-viz"`` the
      transformation is applied to a shallow copy (array data is shared,
      variable and dataset attribute dicts are not).
    - Precipitation type decoding uses GRIB2 Table 4.201 via
      ``tables.get_value_from_table(code, "4.201")`` and returns a NumPy
      array with ``np.dtypes.StringDType``.
    - CF-related lookups are performed using
      ``tables.get_table("shortname_to_cf")``.
    - Vertical coordinate surface names are validated against
      ``VERTICAL_COORDINATE_SURFACES`` before promotion to coordinates.

    Warnings
    --------
    This function assumes the presence of certain GRIB-derived attributes on the
    first data variable (e.g., ``typeOfFirstFixedSurface``,
    ``typeOfSecondFixedSurface``, and possibly ``typeOfProbability``).
    If these are absent or malformed, errors (e.g., ``KeyError``) may occur.

    Examples
    --------
    >>> ds2 = parse_data_model(ds, "nws-viz")
    >>> list(ds2.coords)
    ['forecast_reference_time', 'lead_time', 'time', 'percentile', ...]
    """

    def _decode_ptype(values):
        """
        Decode precipitation type codes into human-readable strings.

        Uses GRIB2 Table 4.201 to map numeric codes to precipitation type
        descriptions.

        Parameters
        ----------
        values : scalar or array_like
            Numeric precipitation type code(s) (integers or floats). Python
            scalars and 0-d arrays are accepted as well as arrays of any shape.

        Returns
        -------
        numpy.ndarray
            Array with the same shape as ``values`` holding the decoded
            strings, using NumPy's string data type (``np.dtypes.StringDType``).
        """
        # asarray + ravel lets scalar attributes and 0-d (squeezed) coordinates
        # through; iterating those directly raises TypeError.
        codes = np.asarray(values)
        decoded = [
            str(tables.get_value_from_table(str(int(code)), '4.201'))
            for code in codes.ravel()
        ]
        return np.array(decoded, dtype=np.dtypes.StringDType).reshape(codes.shape)

    def _surface_coord_name(definition):
        """Turn a fixed-surface definition into a snake_case coordinate name."""
        key = definition.lower().replace(' ', '_')
        # remove special characters
        return re.sub(r'[^a-z0-9_]', '', key)

    def _first_var_attrs(dataset):
        """Return the attrs of the first data variable (IndexError if none)."""
        return dataset[list(dataset.data_vars)[0]].attrs

    if data_model != 'nws-viz':
        return ds

    # Work on a shallow copy: the attribute edits below modify attrs dicts in
    # place and would otherwise leak into the caller's dataset whenever no
    # rename has produced a new object yet.
    ds = ds.copy(deep=False)

    # define regex to convert camelCase to snake_case
    pattern = re.compile(r'(?<!^)(?=[A-Z])')

    # Snapshot the names: ``ds`` is rebound and coordinates are added/dropped
    # inside the loop.
    for coord in list(ds.coords):
        if coord == 'refDate':
            ds = ds.rename({'refDate': 'forecast_reference_time'})

        elif coord == 'leadTime':
            ds = ds.rename({'leadTime': 'lead_time'})

        elif coord == 'validDate':
            ds = ds.rename({'validDate': 'time'})

        elif coord == 'percentileValue':
            ds = ds.rename({'percentileValue': 'percentile'})

        elif coord == 'thresholdLowerLimit':
            ds = ds.rename({'thresholdLowerLimit': 'threshold_lower_limit'})
            ds['threshold_lower_limit'].attrs['long_name'] = 'Threshold Lower Limit'
            ds['threshold_lower_limit'].attrs['units'] = _first_var_attrs(ds)['units']

            if 'PTYPE' in ds.data_vars:
                # keep_attrs=True: otherwise apply_ufunc drops the
                # long_name/units assigned just above.
                ds['threshold_lower_limit'] = xr.apply_ufunc(
                    _decode_ptype, ds['threshold_lower_limit'], keep_attrs=True
                )

            # check if thresholdLowerLimit should be a dimension coordinate
            if 'threshold' in ds.dims:
                prob_types = [
                    'Probability of event below lower limit',
                    'Probability of event above lower limit',
                    'Probability of event equal to lower limit',
                    'Probability of event between upper and lower limits (the range includes lower limit but not the upper limit)'
                ]
                if _first_var_attrs(ds)['typeOfProbability'] in prob_types:
                    ds = ds.swap_dims({'threshold': 'threshold_lower_limit'})

        elif coord == 'thresholdUpperLimit':
            ds = ds.rename({'thresholdUpperLimit': 'threshold_upper_limit'})
            ds['threshold_upper_limit'].attrs['long_name'] = 'Threshold Upper Limit'
            ds['threshold_upper_limit'].attrs['units'] = _first_var_attrs(ds)['units']

            if 'PTYPE' in ds.data_vars:
                # keep_attrs=True: otherwise apply_ufunc drops the
                # long_name/units assigned just above.
                ds['threshold_upper_limit'] = xr.apply_ufunc(
                    _decode_ptype, ds['threshold_upper_limit'], keep_attrs=True
                )

            if 'threshold' in ds.dims:
                prob_types = [
                    'Probability of event below upper limit',
                    'Probability of event above upper limit'
                ]
                if _first_var_attrs(ds)['typeOfProbability'] in prob_types:
                    ds = ds.swap_dims({'threshold': 'threshold_upper_limit'})

        elif coord == 'valueOfFirstFixedSurface':
            # typeOfFirstFixedSurface holds a (definition, units) pair
            definition, units = _first_var_attrs(ds)['typeOfFirstFixedSurface']

            if definition in VERTICAL_COORDINATE_SURFACES:
                key = _surface_coord_name(definition)

                da = ds['valueOfFirstFixedSurface']
                da.attrs['units'] = units
                da.attrs['grib_name'] = ['valueOfFirstFixedSurface', 'typeOfFirstFixedSurface']

                # Assign the coordinate with the new key name
                ds = ds.assign_coords({key: da})

                # Make the derived vertical coordinate the dimension coordinate
                if 'level' in ds.dims:
                    ds = ds.swap_dims({"level": key})

            # Remove the original coordinate (non-mutating)
            ds = ds.drop_vars('valueOfFirstFixedSurface')

        elif coord == 'valueOfSecondFixedSurface':
            # typeOfSecondFixedSurface holds a (definition, units) pair
            definition, units = _first_var_attrs(ds)['typeOfSecondFixedSurface']

            if definition in VERTICAL_COORDINATE_SURFACES:
                key = _surface_coord_name(definition)

                # avoid clobbering a coordinate from the first surface
                if key in ds.coords:
                    key = key + '_2'

                da = ds['valueOfSecondFixedSurface']
                da.attrs['units'] = units
                da.attrs['grib_name'] = ['valueOfSecondFixedSurface', 'typeOfSecondFixedSurface']

                # Assign the coordinate with the new key name
                ds = ds.assign_coords({key: da})

            # Remove the original coordinate (non-mutating)
            ds = ds.drop_vars('valueOfSecondFixedSurface')

        else:
            # change coord name to snake case
            ds = ds.rename({coord: pattern.sub('_', coord).lower()})

    # convert all attributes and variable names to snake case
    for var in list(ds.data_vars):
        record = tables.get_table('shortname_to_cf').get(var)
        var_attrs = ds[var].attrs
        var_attrs['standard_name'] = 'unknown' if record is None else record['cf_standard_name']
        var_attrs['cell_methods'] = 'unknown' if record is None else record['cf_cell_methods']

        # rename variable
        new_var_name = var.lower()
        ds = ds.rename({var: new_var_name})
        # rename() shallow-copies variables, so re-fetch the attrs dict
        attrs = ds[new_var_name].attrs

        # fixed-surface definitions were applied as coordinates above; keep a
        # readable "definition (units)" string in the attrs
        for surface_attr in ('typeOfFirstFixedSurface', 'typeOfSecondFixedSurface'):
            if surface_attr in attrs:
                definition, units = attrs[surface_attr]
                attrs[surface_attr] = f'{definition} ({units})'

        attrs.pop('percentileValue', None)

        if 'threshold_lower_limit' in ds.coords:
            attrs.pop('thresholdLowerLimit', None)

        if 'threshold_upper_limit' in ds.coords:
            attrs.pop('thresholdUpperLimit', None)

        for attr in list(attrs):
            if 'GRIB2IO_section' in attr:
                # replace GRIB2IO with grib in grib section attrs
                new_attr_name = attr.replace('GRIB2IO', 'grib')
            else:
                # change attr name to snake case
                new_attr_name = pattern.sub('_', attr).lower()

            # update new attr name for specific CF names
            if new_attr_name == 'full_name':
                new_attr_name = 'long_name'

            value = attrs.pop(attr)

            # change % to percent
            if attr == 'units' and value == '%':
                value = 'percent'

            if new_var_name == 'ptype' and 'threshold' in new_attr_name:
                decoded = _decode_ptype(value)
                value = decoded.item() if decoded.ndim == 0 else decoded

            attrs[new_attr_name] = value

    # change dataset attrs to snake case
    for attr in list(ds.attrs):
        ds.attrs[pattern.sub('_', attr).lower()] = ds.attrs.pop(attr)

    # change % to percent on coordinates
    for coord in ds.coords:
        coord_attrs = ds[coord].attrs
        if 'units' in coord_attrs and coord_attrs['units'] == '%':
            coord_attrs['units'] = 'percent'

    return ds

Normalize a GRIB2-derived Dataset to a target data model (currently "nws-viz").

When data_model == "nws-viz", this function converts coordinate and variable names to snake_case, derives CF-like metadata, promotes select GRIB-derived quantities to coordinates, optionally swaps dimensions, and standardizes units/attributes. If data_model is anything else, the input dataset is returned unchanged.

Parameters
  • ds (xarray.Dataset): GRIB2-derived dataset whose variables and attributes follow the conventions emitted by grib2io. Expected to contain GRIB-related attributes such as typeOfFirstFixedSurface, typeOfSecondFixedSurface, and (for probabilistic variables) typeOfProbability.
  • data_model (str): Target data model name. Only the value "nws-viz" triggers transformations.
Returns
  • xarray.Dataset: A new dataset with:
    • Selected coordinates renamed: refDate -> forecast_reference_time, leadTime -> lead_time, validDate -> time, percentileValue -> percentile, thresholdLowerLimit -> threshold_lower_limit, thresholdUpperLimit -> threshold_upper_limit.
    • Vertical coordinates derived from valueOfFirstFixedSurface / valueOfSecondFixedSurface and their corresponding typeOf*FixedSurface definitions. New coordinate names are generated from the surface definition (lowercased, spaces to underscores, punctuation removed). If the name already exists, a "_2" suffix is appended.
    • Possible dimension swaps: level -> <derived_vertical_coord> when present; and for probabilistic variables, threshold -> threshold_lower_limit or threshold -> threshold_upper_limit when typeOfProbability indicates the appropriate semantics.
    • Variable names lowercased; dataset- and variable-level attributes converted to snake_case (except GRIB section attributes which are normalized to grib...).
    • CF-adjacent metadata populated: standard_name and cell_methods are set via the shortname→CF lookup table.
    • Percent units normalized from "%" to "percent" on coordinates and data variables.
    • For precipitation type (PTYPE) thresholds, numeric codes are decoded to strings (GRIB2 Table 4.201) in relevant attrs/coords.
Notes
  • Precipitation type decoding uses GRIB2 Table 4.201 via tables.get_value_from_table(code, "4.201") and returns a NumPy array with np.dtypes.StringDType.
  • CF-related lookups are performed using tables.get_table("shortname_to_cf").
  • Vertical coordinate surface names are validated against VERTICAL_COORDINATE_SURFACES before promotion to coordinates.
Warnings

This function assumes the presence of certain GRIB-derived attributes on the first data variable (e.g., typeOfFirstFixedSurface, typeOfSecondFixedSurface, and possibly typeOfProbability). If these are absent or malformed, errors (e.g., KeyError) may occur.

Examples
>>> ds2 = parse_data_model(ds, "nws-viz")
>>> list(ds2.coords)
['forecast_reference_time', 'lead_time', 'time', 'percentile', ...]
class GribBackendEntrypoint(xarray.backends.common.BackendEntrypoint):
436class GribBackendEntrypoint(BackendEntrypoint):
437    """
438    xarray backend engine entrypoint for opening and decoding grib2 files.
439
440    .. warning::
441
442       This backend is experimental and the API/behavior may change without
443       backward compatibility.
444    """
445
446    def open_dataset(
447        self,
448        filename,
449        *,
450        drop_variables=None,
451        save_index=True,
452        filters: typing.Mapping[str, typing.Any] = dict(),
453        data_model=None
454    ):
455        """
456        Read and parse metadata from grib file.
457
458        Parameters
459        ----------
460        filename
461            GRIB2 file to be opened.
462        filters
463            Filter GRIB2 messages to single hypercube. Dict keys can be any
464            GRIB2 metadata attribute name.
465        data_model
466            Parse GRIB metadata following a defined data model comvention.
467
468        Returns
469        -------
470        open_dataset
471            Xarray dataset of grib2 messages.
472        """
473        with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
474            file_index = pd.DataFrame(f._index)
475            file_index = file_index.assign(msg=msgs_from_index(f._index))
476
477        # parse grib2io _index to dataframe and acquire non-geo possible dims
478        # (scalar coord when not dim due to squeeze) parse_grib_index applies
479        # filters to index and expands metadata based on product definition
480        # template number
481        file_index, dim_coords, attrs, coord_attrs = parse_grib_index(file_index, filters)
482
483        # Divide up records by variable
484        frames, cube, extra_geo = make_variables(file_index, filename, dim_coords)  # have this return var_attrs
485
486        # return empty dataset if no data
487        if frames is None:
488            return xr.Dataset()
489
490        # create dataframe and add datarrays without any coords
491        ds = xr.Dataset()
492        for var_df in frames:
493            da = build_da_without_coords(var_df, cube, filename, attrs)
494            ds[da.name] = da
495
496        # add coords and dataset meta
497        ds = assign_xr_meta(ds, frames, cube, dim_coords, extra_geo, coord_attrs)
498
499        if data_model is not None:
500            ds = parse_data_model(ds, data_model)
501
502        # assign attributes
503        ds.attrs['engine'] = 'grib2io'
504
505        return ds
506
507    def open_datatree(
508        self,
509        filename,
510        *,
511        drop_variables=None,
512        save_index=True,
513        filters: typing.Mapping[str, typing.Any] = None,
514        stack_vertical: bool = False,
515    ):
516        """
517        Open a GRIB2 file as an xarray DataTree.
518
519        Parameters
520        ----------
521        filename : str
522            Path to the GRIB2 file.
523        drop_variables : list, optional
524            List of variables to exclude.
525        filters : dict, optional
526            Filter criteria for GRIB2 messages.
527        stack_vertical : bool, optional
528            If True, organize the tree with vertical layers stacked in a single dataset.
529
530        Returns
531        -------
532        xarray.DataTree
533            A hierarchical DataTree representation of the GRIB2 data.
534        """
535        if not _HAS_DATATREE:
536            raise ImportError("xarray version does not support DataTree functionality.")
537
538        if filters is None:
539            filters = {}
540
541        # Open the file without any filters first to get all messages
542        with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
543            file_index = pd.DataFrame(f._index)
544            file_index = file_index.assign(msg=msgs_from_index(f._index))
545
546        # Build tree structure from GRIB messages with specified options
547        tree = build_datatree_from_grib(filename, file_index, filters, stack_vertical=stack_vertical)
548
549        # Put warning here so it is the last message from likely other Xarray warnings.
550        warnings.warn(
551            "grib2io’s xarray backend DataTree support is experimental. "
552            "The DataTree structure or attributes may change in future releases.",
553        UserWarning,
554        stacklevel=2,
555        )
556
557        return tree

xarray backend engine entrypoint for opening and decoding grib2 files.

This backend is experimental and the API/behavior may change without backward compatibility.

def open_dataset( self, filename, *, drop_variables=None, save_index=True, filters: Mapping[str, Any] = {}, data_model=None):
446    def open_dataset(
447        self,
448        filename,
449        *,
450        drop_variables=None,
451        save_index=True,
452        filters: typing.Mapping[str, typing.Any] = dict(),
453        data_model=None
454    ):
455        """
456        Read and parse metadata from grib file.
457
458        Parameters
459        ----------
460        filename
461            GRIB2 file to be opened.
462        filters
463            Filter GRIB2 messages to single hypercube. Dict keys can be any
464            GRIB2 metadata attribute name.
465        data_model
466            Parse GRIB metadata following a defined data model comvention.
467
468        Returns
469        -------
470        open_dataset
471            Xarray dataset of grib2 messages.
472        """
473        with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
474            file_index = pd.DataFrame(f._index)
475            file_index = file_index.assign(msg=msgs_from_index(f._index))
476
477        # parse grib2io _index to dataframe and acquire non-geo possible dims
478        # (scalar coord when not dim due to squeeze) parse_grib_index applies
479        # filters to index and expands metadata based on product definition
480        # template number
481        file_index, dim_coords, attrs, coord_attrs = parse_grib_index(file_index, filters)
482
483        # Divide up records by variable
484        frames, cube, extra_geo = make_variables(file_index, filename, dim_coords)  # have this return var_attrs
485
486        # return empty dataset if no data
487        if frames is None:
488            return xr.Dataset()
489
490        # create dataframe and add datarrays without any coords
491        ds = xr.Dataset()
492        for var_df in frames:
493            da = build_da_without_coords(var_df, cube, filename, attrs)
494            ds[da.name] = da
495
496        # add coords and dataset meta
497        ds = assign_xr_meta(ds, frames, cube, dim_coords, extra_geo, coord_attrs)
498
499        if data_model is not None:
500            ds = parse_data_model(ds, data_model)
501
502        # assign attributes
503        ds.attrs['engine'] = 'grib2io'
504
505        return ds

Read and parse metadata from grib file.

Parameters
  • filename: GRIB2 file to be opened.
  • filters: Filter GRIB2 messages to single hypercube. Dict keys can be any GRIB2 metadata attribute name.
  • data_model: Parse GRIB metadata following a defined data model convention.
Returns
  • open_dataset: Xarray dataset of grib2 messages.
def open_datatree( self, filename, *, drop_variables=None, save_index=True, filters: Mapping[str, Any] = None, stack_vertical: bool = False):
507    def open_datatree(
508        self,
509        filename,
510        *,
511        drop_variables=None,
512        save_index=True,
513        filters: typing.Mapping[str, typing.Any] = None,
514        stack_vertical: bool = False,
515    ):
516        """
517        Open a GRIB2 file as an xarray DataTree.
518
519        Parameters
520        ----------
521        filename : str
522            Path to the GRIB2 file.
523        drop_variables : list, optional
524            List of variables to exclude.
525        filters : dict, optional
526            Filter criteria for GRIB2 messages.
527        stack_vertical : bool, optional
528            If True, organize the tree with vertical layers stacked in a single dataset.
529
530        Returns
531        -------
532        xarray.DataTree
533            A hierarchical DataTree representation of the GRIB2 data.
534        """
535        if not _HAS_DATATREE:
536            raise ImportError("xarray version does not support DataTree functionality.")
537
538        if filters is None:
539            filters = {}
540
541        # Open the file without any filters first to get all messages
542        with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
543            file_index = pd.DataFrame(f._index)
544            file_index = file_index.assign(msg=msgs_from_index(f._index))
545
546        # Build tree structure from GRIB messages with specified options
547        tree = build_datatree_from_grib(filename, file_index, filters, stack_vertical=stack_vertical)
548
549        # Put warning here so it is the last message from likely other Xarray warnings.
550        warnings.warn(
551            "grib2io’s xarray backend DataTree support is experimental. "
552            "The DataTree structure or attributes may change in future releases.",
553        UserWarning,
554        stacklevel=2,
555        )
556
557        return tree

Open a GRIB2 file as an xarray DataTree.

Parameters
  • filename (str): Path to the GRIB2 file.
  • drop_variables (list, optional): List of variables to exclude. (Currently accepted but not applied by this method.)
  • filters (dict, optional): Filter criteria for GRIB2 messages.
  • stack_vertical (bool, optional): If True, organize the tree with vertical layers stacked in a single dataset.
Returns
  • xarray.DataTree: A hierarchical DataTree representation of the GRIB2 data.
class GribBackendArray(xarray.backends.common.BackendArray):
560class GribBackendArray(BackendArray):
561
562    def __init__(self, array, lock):
563        self.array = array
564        self.shape = array.shape
565        self.dtype = np.dtype(array.dtype)
566        self.lock = lock
567
568    def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayLike:
569        return xr.core.indexing.explicit_indexing_adapter(
570            key,
571            self.shape,
572            indexing.IndexingSupport.BASIC,
573            self._raw_getitem,
574        )
575
576    def _raw_getitem(self, key: tuple):
577        """Implement thread safe access to data on disk."""
578        with self.lock:
579            return self.array[key]

Mixin class that extends a class that defines a shape property to one that also defines ndim, size and __len__.

GribBackendArray(array, lock)
562    def __init__(self, array, lock):
563        self.array = array
564        self.shape = array.shape
565        self.dtype = np.dtype(array.dtype)
566        self.lock = lock
array
shape
dtype
lock
def exclusive_slice_to_inclusive(item: slice):
582def exclusive_slice_to_inclusive(item: slice):
583    """
584    Convert a slice with exclusive stop to an inclusive slice.
585
586    If the slice has a step, the stop is reduced by the step, so that both
587    interpretations would yield the same result.
588
589    The means that [start, stop) is converted to [start, stop - step].
590
591    Parameters
592    ----------
593    item
594        The slice to convert.
595
596    Returns
597    -------
598    slice
599        The converted slice.
600    """
601    # return the None slice
602    if item.start is None and item.stop is None and item.step is None:
603        return item
604    if not isinstance(item, slice):
605        raise ValueError(f'item must be a slice; it was of type {type(item)}')
606    # if step is None, it's one
607    step = 1 if item.step is None else item.step
608    if item.stop < item.start or step < 1:
609        raise ValueError(f'slice {item} not accounted for')
610    # handle case where slice has one item
611    if abs(item.stop - item.start) == step:
612        return [item.start]
613    # other cases require reducing the stop by the step
614    s = slice(item.start, item.stop - step, step)
615    return s

Convert a slice with exclusive stop to an inclusive slice.

If the slice has a step, the stop is reduced by the step, so that both interpretations would yield the same result.

This means that [start, stop) is converted to [start, stop - step]. A slice that selects exactly one element is returned as a single-item list, [start].

Parameters
  • item: The slice to convert.
Returns
  • slice: The converted slice.
class Validator:
618class Validator:
619    def __set_name__(self, owner, name):
620        self.private_name = f'_{name}'
621        self.name = name
622
623    def __get__(self, obj, objtype=None):
624        try:
625            value = getattr(obj, self.private_name)
626        except AttributeError:
627            value = None
628        return value
class PdIndex(Validator):
631class PdIndex(Validator):
632
633    def __set__(self, obj, value):
634        try:
635            value = pd.Index(value)
636        except TypeError:
637            value = pd.Index([value])
638        setattr(obj, self.private_name, value)
def array_safe_eq(a, b) -> bool:
659def array_safe_eq(a, b) -> bool:
660    """Check if a and b are equal, even if they are numpy arrays."""
661    if a is b:
662        return True
663    if hasattr(a, 'equals'):
664        return a.equals(b)
665    if hasattr(a, 'all') and hasattr(b, 'all'):
666        return a.shape == b.shape and (a == b).all()
667    if hasattr(a, 'all') or hasattr(b, 'all'):
668        return False
669    try:
670        return a == b
671    except TypeError:
672        return NotImplementedError

Check if a and b are equal, even if they are numpy arrays.

def dc_eq(dc1, dc2) -> bool:
675def dc_eq(dc1, dc2) -> bool:
676    """Check if two dataclasses which hold numpy arrays are equal."""
677    if dc1 is dc2:
678        return True
679    if dc1.__class__ is not dc2.__class__:
680        return NotImplementedError
681    t1 = astuple(dc1)
682    t2 = astuple(dc2)
683    return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))

Check if two dataclasses which hold numpy arrays are equal.

def coords_from_cube(cube) -> Dict[str, xarray.core.variable.Variable]:
686def coords_from_cube(cube) -> typing.Dict[str, xr.Variable]:
687    keys = list(cube.keys())
688    keys.remove('x')
689    keys.remove('y')
690    coords = dict()
691    for k in keys:
692        if k is not None:
693            if len(cube[k]) > 1:
694                coords[k] = xr.Variable(dims=k, data=cube[k], attrs=dict(grib_name=k))
695            elif len(cube[k]) == 1:
696                coords[k] = xr.Variable(dims=tuple(), data=cube[k][0], attrs=dict(grib_name=k))
697    return coords
@dataclass
class OnDiskArray:
700@dataclass
701class OnDiskArray:
702    file_name: str
703    index: pd.DataFrame = field(repr=False)
704    cube: dict = field(repr=False)
705    shape: typing.Tuple[int, ...] = field(init=False)
706    ndim: int = field(init=False)
707    geo_ndim: int = field(init=False)
708    dtype = 'float32'
709
710    def __post_init__(self):
711        # multiple grids not allowed so can just use first
712        geo_shape = (self.index.iloc[0].ny, self.index.iloc[0].nx)
713
714        self.geo_shape = geo_shape
715        self.geo_ndim = len(geo_shape)
716
717        if len(self.index) == 1:
718            self.shape = geo_shape
719        else:
720            if self.index.index.nlevels == 1:
721                self.shape = tuple([len(self.index.index)]) + geo_shape
722            else:
723                self.shape = tuple([len(i) for i in self.index.index.levels]) + geo_shape
724        self.ndim = len(self.shape)
725
726        cols = ['msg', 'sectionOffset']
727        self.index = self.index[cols]
728
729    def __getitem__(self, item) -> np.array:
730        # dimensions not in index are internal to tdlpack records; 2 dims for
731        # grids; 1 dim for stations
732
733        index_slicer = item[:-self.geo_ndim]
734        # maintain all multindex levels
735        index_slicer = tuple([[i] if isinstance(i, int) else i for i in index_slicer])
736
737        # pandas loc slicing is inclusive, therefore convert slices into
738        # explicit lists
739        index_slicer_inclusive = tuple([exclusive_slice_to_inclusive(
740            i) if isinstance(i, slice) else i for i in index_slicer])
741
742        # get records selected by item in new index dataframe
743        if len(index_slicer_inclusive) == 1:
744            index = self.index.loc[index_slicer_inclusive]
745        elif len(index_slicer_inclusive) > 1:
746            index = self.index.loc[index_slicer_inclusive, :]
747        else:
748            index = self.index
749        index = index.set_index(index.index)
750
751        # set miloc to new relative locations in sub array
752        index['miloc'] = list(
753            zip(*[index.index.unique(level=dim).get_indexer(index.index.get_level_values(dim)) for dim in index.index.names]))
754
755        if len(index_slicer_inclusive) == 1:
756            array_field_shape = tuple([len(index.index)]) + self.geo_shape
757        elif len(index_slicer_inclusive) > 1:
758            array_field_shape = index.index.levshape + self.geo_shape
759        else:
760            array_field_shape = self.geo_shape
761
762        array_field = np.full(array_field_shape, fill_value=np.nan, dtype="float32")
763
764        with open(self.file_name, mode='rb') as filehandle:
765            for key, row in index.iterrows():
766
767                bitmap_offset = None if pd.isna(row['sectionOffset'][6]) else int(row['sectionOffset'][6])
768                values = _data(filehandle, row.msg, bitmap_offset, row['sectionOffset'][7])
769
770                if len(index_slicer_inclusive) >= 1:
771                    array_field[row.miloc] = values
772                else:
773                    array_field = values
774
775        # handle geo dim slicing
776        array_field = array_field[(Ellipsis,) + item[-self.geo_ndim:]]
777
778        # squeeze array dimensions expressed as integer
779        for i, it in reversed(list(enumerate(item[: -self.geo_ndim]))):
780            if isinstance(it, int):
781                array_field = array_field[(slice(None, None, None),) * i + (0,)]
782
783        return array_field
OnDiskArray(file_name: str, index: pandas.core.frame.DataFrame, cube: dict)
file_name: str
index: pandas.core.frame.DataFrame
cube: dict
shape: Tuple[int, ...]
ndim: int
geo_ndim: int
dtype = 'float32'
def dims_to_shape(d) -> tuple:
786def dims_to_shape(d) -> tuple:
787    if 'nx' in d:
788        t = (d['ny'], d['nx'])
789    else:
790        t = (d['nsta'],)
791    return t
def filter_index(index, k, v):
794def filter_index(index, k, v):
795    if isinstance(v, slice):
796        index = index.set_index(k)
797        index = index.loc[v]
798        index = index.reset_index()
799    else:
800        label = (
801            v
802            if getattr(v, "ndim", 1) > 1  # vectorized-indexing
803            else _asarray_tuplesafe(v)
804        )
805        if label.ndim == 0:
806            # see https://github.com/pydata/xarray/pull/4292 for details
807            label_value = label[()] if label.dtype.kind in "mM" else label.item()
808            try:
809                indexer = pd.Index(index[k]).get_loc(label_value)
810                if isinstance(indexer, int):
811                    index = index.iloc[[indexer]]
812                else:
813                    index = index.iloc[indexer]
814            except KeyError:
815                index = index.iloc[[]]
816        else:
817            indexer = pd.Index(index[k]).get_indexer_for(np.ravel(v))
818            index = index.iloc[indexer[indexer >= 0]]
819
820    return index
def parse_grib_index(index: pandas.core.frame.DataFrame, filters: Mapping[str, Any] = {}):
823def parse_grib_index(
824    index: pd.DataFrame,
825    filters: typing.Mapping[str, typing.Any] = dict(),
826):
827    """
828    Apply filters.
829
830    Evaluate remaining dimensions based on pdtn and parse each out.
831
832    Parameters
833    ----------
834    index
835        Pandas DataFrame containing the GRIB2 message index.
836    filters
837        Filter GRIB2 messages to single hypercube. Dict keys can be any
838        GRIB2 metadata attribute name.
839
840    Returns
841    -------
842    index
843        Modified Pandas DataFrame with added GRIB2 metadata columns.
844    dim_coords
845        List of GRIB2 attributes that will be used for coordinates and/or dimensions.
846    attrs
847        Dict of metadata attributes (non-coordinates, non-geo)
848    """
849
850    # make a copy of filters, remove filters as they are applied
851    filters = copy(filters)
852
853    for k, v in filters.items():
854        if k not in index.columns:
855            kwarg = {k: index.msg.apply(lambda msg: getattr(msg, k))}
856            index = index.assign(**kwarg)
857        # adopt parts of xarray's sel logic  so that filters behave similarly
858        # allowed to filter to nothing to make empty dataset
859        index = filter_index(index, k, v)
860
861    if len(index) == 0:
862        return index, list()
863
864    dim_coords = dict()  # key=name of dim, value=list of coord names
865    attrs = dict()
866    coord_attrs = dict()
867
868    # expand index
869    index = index.assign(shortName=index.msg.apply(lambda msg: msg.shortName))
870    index = index.assign(nx=index.msg.apply(lambda msg: msg.nx))
871    index = index.assign(ny=index.msg.apply(lambda msg: msg.ny))
872    index = index.astype({'ny': 'int', 'nx': 'int'})
873
874    # apply common filters(to all definition templates) to reduce dataset to
875    # single cube
876    # ensure only one of each of the below exists after filters applied
877    required_uniques = [
878        "productDefinitionTemplateNumber",
879        "typeOfGeneratingProcess",
880        "typeOfFirstFixedSurface",
881        "typeOfSecondFixedSurface",
882    ]
883
884    def meta_check(index, attrs, meta):
885        """
886        add meta to the datframe index
887        check that there is a single type
888        add the type to attrs
889
890        returns index, attrs
891        """
892        index = index.assign(**{meta: index.msg.apply(lambda msg: getattr(msg, meta))})
893
894        unique = index[meta].unique()
895        if len(index[meta].unique()) > 1:
896            raise ValueError(f'filter to a single {meta}; found: {[str(i) for i in unique]}')
897        value = unique.item()
898        if type(value) == grib2io.templates.Grib2Metadata:
899            value = value.definition
900
901        # None is returned if no value found,
902        # check and change to string None
903        if value is None:
904            value = 'None'
905
906        attrs[meta] = value
907        return index, attrs
908
909    for meta in required_uniques:
910        index, attrs = meta_check(index, attrs, meta)
911
912    pdtn = index.productDefinitionTemplateNumber.iloc[0].value
913
914    # determine which non geo dimensions can be created from data by this point
915    # the index is filtered down to a single type for all required_uniques
916
917    # Dim Name     # matching dim_name for using this data as index coordinate
918    dim_coords["refDate"] = ["refDate"]
919    coord_attrs["refDate"] = dict(standard_name="forecast_reference_time")
920#   dim_coords["refDate"] = ["refDate", "hour"] # non dim name matching items in list are used as non-index coordinates
921
922    dim_coords["leadTime"] = ["leadTime"]
923    coord_attrs["leadTime"] = dict(standard_name="forecast_period")
924
925    if 'valueOfFirstFixedSurface' not in index.columns:
926        index = index.assign(valueOfFirstFixedSurface=index.msg.apply(lambda msg: msg.valueOfFirstFixedSurface))
927    if 'valueOfsecondFixedSurface' not in index.columns:
928        index = index.assign(valueOfSecondFixedSurface=index.msg.apply(lambda msg: msg.valueOfSecondFixedSurface))
929
930    # dim name api change, user could run ds = ds.swap_dims(fixedSurface="valueOfFirstFixedSurface")
931    index = index.assign(level=list(zip(index['valueOfFirstFixedSurface'], index['valueOfSecondFixedSurface'])))
932#   index = index.assign(level=index.msg.apply(lambda msg: msg.level))
933    # lack of "level" indeicates don't create extra index coordinate "level"
934    dim_coords["level"] = ["valueOfFirstFixedSurface", "valueOfSecondFixedSurface"]
935
936    # logic for parsing possible dims from specific product definition section
937
938    if pdtn in {5, 9}:
939
940        # Probability forecasts at a horizontal level or in a horizontal layer
941        # in a continuous or non-continuous time interval.  (see Template
942        # 4.9)
943        #       AVAILABLE_THRESHOLD = {
944        #           0: {'has_lower': True, 'has_upper': False},
945        #           1: {'has_lower': False, 'has_upper': True},
946        #           2: {'has_lower': True, 'has_upper': True},
947        #           3: {'has_lower': True, 'has_upper': False},
948        #           4: {'has_lower': False, 'has_upper': True},
949        #           5: {'has_lower': True, 'has_upper': False},
950        #       }
951
952        index, attrs = meta_check(index, attrs, "typeOfProbability")
953        if 'thresholdLowerLimit' not in index.columns:
954            index = index.assign(thresholdLowerLimit=index.msg.apply(lambda msg: msg.thresholdLowerLimit))
955        if 'thresholdUpperLimit' not in index.columns:
956            index = index.assign(thresholdUpperLimit=index.msg.apply(lambda msg: msg.thresholdUpperLimit))
957        if 'threshold' not in index.columns:
958            # using composite of lower and upper, but could use threshold string from grib2io as long as that is unique and based on lower and upper
959            index = index.assign(threshold=list(zip(index['thresholdLowerLimit'], index['thresholdUpperLimit'])))
960#           index = index.assign(threshold = index.msg.apply(lambda msg: msg.threshold))
961
962        # ommiting threshold results in no index being assigned for this possible dim
963        dim_coords["threshold"] = ["thresholdLowerLimit", "thresholdUpperLimit"]
964
965    if pdtn in {6, 10}:
966
967        # Percentile forecasts at a horizontal level or in a horizontal layer
968        # in a continuous or non-continuous time interval.  (see Template
969        # 4.10)
970        dim_coords["percentileValue"] = ["percentileValue"]
971        coord_attrs["percentileValue"] = dict(long_name='percentile', units='percent')
972
973    if pdtn in {8, 9, 10, 11, 12, 13, 14, 42, 43, 45, 46, 47, 61, 62, 63, 67, 68, 72, 73, 78, 79, 82, 83, 84, 85, 87, 91}:
974        dim_coords["duration"] = ["duration"]
975
976    if pdtn in {1, 11, 33, 34, 41, 43, 45, 47, 49, 54, 56, 58, 59, 63, 68, 77, 79, 81, 83, 84, 85, 92}:
977        dim_coords["perturbationNumber"] = ["perturbationNumber"]
978
979    if pdtn in {2,3,4,12,13,14}:
980        index, attrs = meta_check(index, attrs, 'typeOfDerivedForecast')
981
982    if pdtn in {8,15,42,46,62,67,72,78,82,1001,1002,1100,1101}:
983        index, attrs = meta_check(index, attrs, 'statisticalProcess')
984
985    # Finish logic by pdtn
986
987    for k, v in dim_coords.items():
988        for meta in v:
989            if meta not in index.columns:
990                index = index.assign(**{meta: index.msg.apply(lambda msg: getattr(msg, meta))})
991
992    return index, dim_coords, attrs, coord_attrs

Apply filters.

Evaluate remaining dimensions based on pdtn and parse each out.

Parameters
  • index: Pandas DataFrame containing the GRIB2 message index.
  • filters: Filter GRIB2 messages to single hypercube. Dict keys can be any GRIB2 metadata attribute name.
Returns
  • index: Modified Pandas DataFrame with added GRIB2 metadata columns.
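  • dim_coords: Dict mapping each possible dimension name to the list of GRIB2 attribute names used for its coordinates.
  • attrs: Dict of metadata attributes (non-coordinates, non-geo)
  • coord_attrs: Dict mapping coordinate names to attribute dicts (e.g. standard_name, long_name, units) for those coordinates.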
def open_datatree(filename, *, filters: Mapping[str, Any] = None, engine='grib2io'):
 996def open_datatree(filename, *, filters: typing.Mapping[str, typing.Any] = None, engine="grib2io"):
 997    """
 998    Open a GRIB2 file as an xarray DataTree.
 999
1000    Parameters
1001    ----------
1002    filename : str
1003        Path to the GRIB2 file.
1004    filters : dict, optional
1005        Filter criteria for GRIB2 messages.
1006    engine : str, optional
1007        Engine to use for opening the file, defaults to "grib2io".
1008
1009    Returns
1010    -------
1011    xarray.DataTree
1012        A hierarchical DataTree representation of the GRIB2 data.
1013    """
1014    if not _HAS_DATATREE:
1015        raise ImportError("xarray version does not support DataTree functionality.")
1016
1017    if filters is None:
1018        filters = {}
1019
1020    # Open the file without any filters first to get all messages
1021    with grib2io.open(filename, _xarray_backend=True) as f:
1022        file_index = pd.DataFrame(f._index)
1023
1024    # Create a DataTree root
1025    tree = xr.DataTree()
1026
1027    # Build tree structure from GRIB messages
1028    return build_datatree_from_grib(filename, file_index, filters)

Open a GRIB2 file as an xarray DataTree.

Parameters
  • filename (str): Path to the GRIB2 file.
  • filters (dict, optional): Filter criteria for GRIB2 messages.
  • engine (str, optional): Engine to use for opening the file, defaults to "grib2io". (Currently not used by this function; the file is always opened with grib2io.)
Returns
  • xarray.DataTree: A hierarchical DataTree representation of the GRIB2 data.
def build_da_without_coords(index, cube, filename, attrs) -> xarray.core.dataarray.DataArray:
1031def build_da_without_coords(index, cube, filename, attrs) -> xr.DataArray:
1032    """
1033    Build a DataArray without coordinates from a cube of grib2 messages.
1034
1035    Parameters
1036    ----------
1037    index
1038        Index of cube.
1039    cube
1040        Cube of grib2 messages.
1041    filename
1042        Filename of grib2 file
1043    add_grib_section_attrs
1044        Include grib section arrays as dataArray attributes
1045
1046    Returns
1047    -------
1048    DataArray
1049        DataArray without coordinates
1050    """
1051
1052    dim_names = [k for k in cube.keys() if cube[k] is not None and len(cube[k]) > 1]
1053    constant_meta_names = [k for k in cube.keys() if cube[k] is None]
1054    dims = {k: len(cube[k]) for k in dim_names}
1055
1056    # guard against bad datarrays being formed
1057    dims_total = 1
1058    dims_to_filter = []
1059    for dim_name, dim_len, in dims.items():
1060        if dim_name not in {'x', 'y', 'station'}:
1061            dims_total *= dim_len
1062            dims_to_filter.append(dim_name)
1063
1064    # Check number of GRIB2 message indexed compared to non-X/Y
1065    # dimensions.
1066    if dims_total != len(index):
1067        raise ValueError(
1068            f"DataArray dimensions are not compatible with number of GRIB2 messages; DataArray has {dims_total} "
1069            f"and GRIB2 index has {len(index)}. Consider applying a filter for dimensions: {dims_to_filter}"
1070        )
1071
1072    data = OnDiskArray(filename, index, cube)
1073    lock = _LOCK
1074    data = GribBackendArray(data, lock)
1075    data = indexing.LazilyIndexedArray(data)
1076    if len(dim_names) != len(data.shape):
1077        raise ValueError(
1078            "different number of dimensions on data "
1079            f"and dims: {len(data.shape)} vs {len(dim_names)}\n"
1080            "Grib2 messages could not be formed into a data cube; "
1081            "It's possible extra messages exist along a non-accounted for dimension based on PDTN\n"
1082            "It might be possible to get around this by applying a filter on the non-accounted for dimension"
1083        )
1084    da = xr.DataArray(data, dims=dim_names)
1085
1086    da.encoding['original_shape'] = data.shape
1087
1088    da.encoding['preferred_chunks'] = {'y': -1, 'x': -1}
1089    msg1 = index.msg.iloc[0]
1090
1091    # plain language metadata is minimized
1092    # add grib section metadata
1093    da.attrs['GRIB2IO_section0'] = msg1.section0
1094    da.attrs['GRIB2IO_section1'] = msg1.section1
1095    da.attrs['GRIB2IO_section2'] = msg1.section2 if msg1.section2 else []
1096    da.attrs['GRIB2IO_section3'] = msg1.section3
1097    da.attrs['GRIB2IO_section4'] = msg1.section4
1098    da.attrs['GRIB2IO_section5'] = msg1.section5
1099    da.attrs['fullName'] = str(msg1.fullName)
1100    da.attrs['shortName'] = str(msg1.shortName)
1101    da.attrs['units'] = str(msg1.units)
1102    da.attrs['originatingCenter'] = str(msg1.originatingCenter.definition)
1103    da.attrs['originatingSubCenter'] = str(msg1.originatingSubCenter.definition)
1104
1105    # add master table
1106    da.attrs['masterTableInfo'] = str(msg1.masterTableInfo.definition)
1107
1108    da.name = index.shortName.iloc[0]
1109    for meta_name in constant_meta_names:
1110        if meta_name in index.columns:
1111            da.attrs[meta_name] = index[meta_name].iloc[0]
1112
1113    da.attrs.update(attrs)
1114
1115    return da

Build a DataArray without coordinates from a cube of grib2 messages.

Parameters
  • index: Index of cube.
  • cube: Cube of grib2 messages.
  • filename: Filename of grib2 file
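  • attrs: Dict of additional metadata attributes merged into the DataArray's attributes. GRIB section arrays are always added as the GRIB2IO_section0 through GRIB2IO_section5 attributes.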
Returns
  • DataArray: DataArray without coordinates
def assign_xr_meta(ds, frames, cube, non_geo_dims, extra_geo, coord_attrs):
def assign_xr_meta(ds, frames, cube, non_geo_dims, extra_geo, coord_attrs):
    """
    Assign coordinates and metadata attributes to a Dataset of grib2 variables.

    Parameters
    ----------
    ds
        Dataset whose data variables were built from the cube.
    frames
        List of per-variable dataframes (presumably the ``ordered_frames``
        returned by ``make_variables``). Only the first one is read, because
        all variables share the same shape and index metadata.
    cube
        Cube of grib2 messages; converted to coordinates with
        ``coords_from_cube``.
    non_geo_dims
        Mapping of dimension name to the coordinate names associated with
        that dimension.
    extra_geo
        Extra geographic coordinates (e.g. latitude/longitude) to assign.
    coord_attrs
        Mapping of coordinate name to attributes added to that coordinate.

    Returns
    -------
    ds
        Dataset with coordinates, CRS/grid attributes, a ``validDate``
        coordinate (when it can be derived) and ``engine='grib2io'``.
    """
    # assign coords from the cube; the cube prevents DataArrays with
    # different shapes
    ds = ds.assign_coords(coords_from_cube(cube))

    # assign extra index associated coords; use the first variable as they
    # all have the same shape and index metadata
    df = frames[0]
    for dim_name, coord_names in non_geo_dims.items():
        retain_index_coord = False
        for name in coord_names:
            if name == dim_name:
                retain_index_coord = True
            elif ds[dim_name].size == 1:
                # for assigning scalar coords
                coord_data = [df[name].unique().item()]
                ds = ds.assign_coords({name: coord_data}).squeeze()
            else:
                # "ValueError: can only convert an array of size 1 to a Python
                # scalar" indicates the coord is not compatible with the index
                level_values = df.index.get_level_values(f'{dim_name}_ix')
                coord_data = [df[level_values == val][name].unique().item()
                              for val in range(ds[dim_name].size)]
                coord = pd.Index(coord_data, name=dim_name)
                ds = ds.assign_coords({name: coord})
        if not retain_index_coord:
            ds = ds.drop_vars(dim_name)

    # assign extra geo coords
    ds = ds.assign_coords(extra_geo)

    # add crs data from the first grib message to each data variable and
    # to the dataset itself
    msg = df.msg.iloc[0]
    geo_attrs = {
        'crs_wkt': CRS.from_dict(msg.projParameters).to_wkt(),
        'gridlengthXDirection': msg.gridlengthXDirection,
        'gridlengthYDirection': msg.gridlengthYDirection,
        'latitudeFirstGridpoint': msg.latitudeFirstGridpoint,
        'longitudeFirstGridpoint': msg.longitudeFirstGridpoint,
    }
    for data_var in ds.data_vars:
        ds[data_var].attrs.update(geo_attrs)
    ds.attrs.update(geo_attrs)

    # add coordinate specific attributes
    for coord, attrs in coord_attrs.items():
        ds[coord].attrs.update(attrs)

    # assign valid date coords; best effort, since refDate and leadTime may
    # not both be present as coordinates
    try:
        ds = ds.assign_coords(dict(validDate=ds.coords['refDate'] + ds.coords['leadTime']))
        ds.validDate.attrs['standard_name'] = 'time'
        ds.validDate.attrs['long_name'] = 'time'
    except Exception as e:
        warnings.warn(f'could not parse validDate: {e}')

    # assign attributes
    ds.attrs['engine'] = 'grib2io'

    return ds
def make_variables(index, f, non_geo_dims, allow_uneven_dims=False):
1176def make_variables(index, f, non_geo_dims, allow_uneven_dims=False):
1177    """
1178    Create an individual dataframe index and cube for each variable.
1179
1180    Parameters
1181    ----------
1182    index
1183        Index of cube.
1184    f
1185        ?
1186    non_geo_dims
1187        Dimensions not associated with the x,y grid
1188    allow_uneven_dims
1189        If True, allows uneven dimensions (used for DataTree creation)
1190
1191    Returns
1192    -------
1193    ordered_frames
1194        List of dataframes, one for each variable.
1195    cube
1196        Cube of grib2 messages.
1197    extra_geo
1198        Extra geographic coordinates.
1199    """
1200    # let shortName determine the variables
1201
1202    # set the index to the name
1203    index = index.set_index('shortName').sort_index()
1204    # return nothing if no data
1205    if index.empty:
1206        return None, None, None
1207
1208    # define the DimCube
1209    dims = copy(non_geo_dims)
1210
1211    ordered_meta = list(non_geo_dims.keys())
1212    cube = None
1213    ordered_frames = list()
1214    for key in index.index.unique():
1215        frame = index.loc[[key]]
1216        frame = frame.reset_index()
1217        # frame is a dataframe with all records for one variable
1218        c = dict()
1219        # for colname in frame.columns:
1220        for colname in ordered_meta:
1221            uniques = pd.Index(frame[colname]).unique()
1222            if len(uniques) > 1:
1223                c[colname] = uniques.sort_values()
1224            else:
1225                c[colname] = [uniques[0]]
1226
1227        dims = [k for k in ordered_meta if len(c[k]) > 1]
1228
1229        for dim in dims:
1230            if frame[dim].value_counts().nunique() > 1 and not allow_uneven_dims:
1231                raise ValueError(
1232                    f'uneven number of grib msgs associated with dimension: {dim}\n unique values for {dim}: {frame[dim].unique()} ')
1233
1234        if len(dims) >= 1:  # dims may be empty if no extra dims on top of x,y
1235            frame = frame.sort_values(dims)
1236            frame = frame.set_index(dims)
1237
1238        if cube:
1239            if cube != c and not allow_uneven_dims:
1240                raise ValueError(f'{cube},\n {c};\n cubes are not the same; filter to a single cube')
1241        else:
1242            cube = c
1243
1244        # miloc is multi-index integer location of msg in nd DataArray
1245        miloc = list(zip(*[frame.index.unique(level=dim).get_indexer(frame.index.get_level_values(dim))
1246                     for dim in dims]))
1247
1248        # set frame multi index
1249        if len(miloc) >= 1:  # miloc will be empty when no extra dims, thus no multiindex
1250            dim_ix = tuple([n+'_ix' for n in dims])
1251            frame = frame.set_index(pd.MultiIndex.from_tuples(miloc, names=dim_ix))
1252
1253        ordered_frames.append(frame)
1254
1255    # no variables
1256    if cube is None:
1257        cube = dict()
1258
1259    # check geography of data and assign to cube
1260    if len(index.ny.unique()) > 1 or len(index.nx.unique()) > 1:
1261        raise ValueError('multiple grids not accommodated')
1262    cube["y"] = range(int(index.ny.iloc[0]))
1263    cube["x"] = range(int(index.nx.iloc[0]))
1264
1265    extra_geo = None
1266    msg = index.msg.iloc[0]
1267
1268    # we want the lat lons; make them via accessing a record; we are assuming
1269    # all records are the same grid because they have the same shape;
1270    # may want a unique grid identifier from grib2io to avoid assuming this
1271    latitude, longitude = msg.latlons()
1272    latitude = xr.DataArray(latitude, dims=['y', 'x'])
1273    latitude.attrs['standard_name'] = 'latitude'
1274    latitude.attrs['units'] = 'degrees_north'
1275    longitude = xr.DataArray(longitude, dims=['y', 'x'])
1276    longitude.attrs['standard_name'] = 'longitude'
1277    longitude.attrs['units'] = 'degrees_east'
1278    extra_geo = dict(latitude=latitude, longitude=longitude)
1279
1280    return ordered_frames, cube, extra_geo

Create an individual dataframe index and cube for each variable.

Parameters
  • index: Index of cube.
  • f: Not used by this function.
  • non_geo_dims: Dimensions not associated with the x,y grid
  • allow_uneven_dims: If True, allows uneven dimensions (used for DataTree creation)
Returns
  • ordered_frames: List of dataframes, one for each variable.
  • cube: Cube of grib2 messages.
  • extra_geo: Extra geographic coordinates.
def interp_nd( a, *, method, grid_def_in, grid_def_out, method_options=None, num_threads=1):
def interp_nd(a, *, method, grid_def_in, grid_def_out, method_options=None, num_threads=1):
    """
    Interpolate an array whose two rightmost axes are the (y, x) grid.

    All leading axes are collapsed into a single stack of 2-D fields for
    ``grib2io.interpolate``. They are then restored in front of the output
    grid's (y, x) axes.
    """
    lead_shape = a.shape[:-2]
    ny, nx = a.shape[-2], a.shape[-1]
    stacked = a.reshape(-1, ny, nx)
    result = grib2io.interpolate(stacked, method, grid_def_in, grid_def_out,
                                 method_options=method_options,
                                 num_threads=num_threads)
    out_grid = (result.shape[-2], result.shape[-1])
    return result.reshape(lead_shape + out_grid)
def interp_nd_stations( a, *, method, grid_def_in, lats, lons, method_options=None, num_threads=1):
def interp_nd_stations(a, *, method, grid_def_in, lats, lons, method_options=None, num_threads=1):
    """
    Interpolate an array whose two rightmost axes are the (y, x) grid to stations.

    All leading axes are collapsed into a single stack of 2-D fields for
    ``grib2io.interpolate_to_stations``. The result is reshaped so that those
    leading axes are followed by one station axis of length ``len(lats)``.
    """
    lead_shape = a.shape[:-2]
    ny, nx = a.shape[-2], a.shape[-1]
    stacked = a.reshape(-1, ny, nx)
    result = grib2io.interpolate_to_stations(stacked, method, grid_def_in, lats, lons,
                                             method_options=method_options,
                                             num_threads=num_threads)
    n_stations = len(lats)
    return result.reshape(lead_shape + (n_stations,))
@xr.register_dataset_accessor('grib2io')
class Grib2ioDataSet:
@xr.register_dataset_accessor("grib2io")
class Grib2ioDataSet:
    """
    ``.grib2io`` accessor for Datasets decoded by the grib2io engine.

    Most methods delegate to the DataArray accessor for each data variable.
    The first data variable's ``GRIB2IO_section3`` attribute serves as the
    grid definition for the whole Dataset.
    """

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def _first_section3(self):
        """Return the ``GRIB2IO_section3`` attribute of the first data variable."""
        return self._obj[list(self._obj.data_vars)[0]].attrs['GRIB2IO_section3']

    def griddef(self):
        """Return the Grib2GridDef of the Dataset, taken from its first data variable."""
        return Grib2GridDef.from_section3(self._first_section3())

    def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.Dataset:
        """
        Perform grid spatial interpolation of every data variable.

        The data variables are stacked along a ``variable`` dimension with
        ``Dataset.to_array``, interpolated with the DataArray accessor's
        ``interp`` and split back into a Dataset.

        Parameters
        ----------
        method
            Interpolate method (integer or string); see the DataArray
            accessor's ``interp``.
        grid_def_out
            Grib2GridDef object of the output grid.
        method_options : list of ints, optional
            Interpolation options passed through to NCEPLIBS-ip.
        num_threads : int, optional
            Number of OpenMP threads to use for interpolation.

        Returns
        -------
        interp
            Dataset interpolated to the new grid definition.
        """
        da = self._obj.to_array()
        # copy the grid definition onto the stacked array so its accessor can
        # build grid_def_in
        da.attrs['GRIB2IO_section3'] = self._first_section3()
        da = da.grib2io.interp(method, grid_def_out, method_options=method_options,
                               num_threads=num_threads)
        return da.to_dataset(dim='variable')

    def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.Dataset:
        """
        Perform spatial interpolation of every data variable to station points.

        Parameters
        ----------
        method
            Interpolate method (integer or string); see the DataArray
            accessor's ``interp_to_stations``.
        calls
            Station calls used for labeling the new ``station`` coordinate.
        lats
            Latitudes of the station points.
        lons
            Longitudes of the station points.
        method_options : list of ints, optional
            Interpolation options passed through to NCEPLIBS-ip.
        num_threads : int, optional
            Number of OpenMP threads to use for interpolation.

        Returns
        -------
        interp_to_stations
            Dataset whose variables are interpolated to the stations,
            (..., y, x) -> (..., station).
        """
        da = self._obj.to_array()
        # copy the grid definition onto the stacked array so its accessor can
        # build grid_def_in
        da.attrs['GRIB2IO_section3'] = self._first_section3()
        da = da.grib2io.interp_to_stations(method, calls, lats, lons, method_options=method_options,
                                           num_threads=num_threads)
        return da.to_dataset(dim='variable')

    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
        """
        Write a DataSet to a grib2 file.

        Data variables are written in sorted name order. ``mode`` applies to
        the first variable; every later variable is appended.

        Parameters
        ----------
        filename
            Name of the grib2 file to write to.
        mode: {"x", "w", "a"}, optional, default="x"
            Persistence mode

            | mode | Description                       |
            | :---:| :---:                             |
            | 'x'  | create (fail if exists)           |
            | 'w'  | create (overwrite if exists)      |
            | 'a'  | append (create if does not exist) |

        """
        ds = self._obj

        for shortName in sorted(ds):
            # make a DataArray from the "Data Variables" in the DataSet
            da = ds[shortName]

            da.grib2io.to_grib2(filename, mode=mode)
            mode = "a"

    def update_attrs(self, **kwargs):
        """
        Raise an error; attribute updates are only supported on DataArrays.

        GRIB2 section metadata is stored on each data variable, so call
        ``ds[name].grib2io.update_attrs(...)`` on each variable instead.

        Parameters
        ----------
        **kwargs
            Attributes requested for update; they are only echoed in the
            error message.

        Raises
        ------
        ValueError
            Always.
        """
        raise ValueError(
            f"Datasets do not have a .attrs attribute; use .grib2io.update_attrs({kwargs}) on a DataArray instead."
        )

    def subset(self, lats, lons) -> xr.Dataset:
        """
        Subset the DataSet to a region defined by latitudes and longitudes.

        Parameters
        ----------
        lats
            Latitude bounds of the region.
        lons
            Longitude bounds of the region.

        Returns
        -------
        subset
            DataSet subset to the region.
        """
        ds = self._obj

        newds = xr.Dataset()
        for shortName in ds:
            newds[shortName] = ds[shortName].grib2io.subset(lats, lons).copy()

        return newds
Grib2ioDataSet(xarray_obj)
1304    def __init__(self, xarray_obj):
1305        self._obj = xarray_obj
def griddef(self):
1307    def griddef(self):
1308        return Grib2GridDef.from_section3(self._obj[list(self._obj.data_vars)[0]].attrs['GRIB2IO_section3'])
def interp( self, method, grid_def_out, method_options=None, num_threads=1) -> xarray.core.dataset.Dataset:
1310    def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.Dataset:
1311        # see interp method of class Grib2ioDataArray
1312        da = self._obj.to_array()
1313        da.attrs['GRIB2IO_section3'] = self._obj[list(self._obj.data_vars)[0]].attrs['GRIB2IO_section3']
1314        da = da.grib2io.interp(method, grid_def_out, method_options=method_options,
1315                               num_threads=num_threads)
1316        ds = da.to_dataset(dim='variable')
1317        return ds
def interp_to_stations( self, method, calls, lats, lons, method_options=None, num_threads=1) -> xarray.core.dataset.Dataset:
1319    def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.Dataset:
1320        # see interp_to_stations method of class Grib2ioDataArray
1321        da = self._obj.to_array()
1322        da.attrs['GRIB2IO_section3'] = self._obj[list(self._obj.data_vars)[0]].attrs['GRIB2IO_section3']
1323        da = da.grib2io.interp_to_stations(method, calls, lats, lons, method_options=method_options,
1324                                           num_threads=num_threads)
1325        ds = da.to_dataset(dim='variable')
1326        return ds
def to_grib2(self, filename, mode: Literal['x', 'w', 'a'] = 'x'):
1328    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
1329        """
1330        Write a DataSet to a grib2 file.
1331
1332        Parameters
1333        ----------
1334        filename
1335            Name of the grib2 file to write to.
1336        mode: {"x", "w", "a"}, optional, default="x"
1337            Persistence mode
1338
1339            | mode | Description                       |
1340            | :---:| :---:                             |
1341            | 'x'  | create (fail if exists)           |
1342            | 'w'  | create (overwrite if exists)      |
1343            | 'a'  | append (create if does not exist) |
1344
1345        """
1346        ds = self._obj
1347
1348        for shortName in sorted(ds):
1349            # make a DataArray from the "Data Variables" in the DataSet
1350            da = ds[shortName]
1351
1352            da.grib2io.to_grib2(filename, mode=mode)
1353            mode = "a"

Write a DataSet to a grib2 file.

Parameters
  • filename: Name of the grib2 file to write to.
  • mode ({"x", "w", "a"}, optional, default="x"): Persistence mode

    mode Description
    'x' create (fail if exists)
    'w' create (overwrite if exists)
    'a' append (create if does not exist)
def update_attrs(self, **kwargs):
1355    def update_attrs(self, **kwargs):
1356        """
1357        Raises an error because Datasets don't have a .attrs attribute.
1358
1359        Parameters
1360        ----------
1361        attrs
1362            Attributes to update.
1363        """
1364        raise ValueError(
1365            f"Datasets do not have a .attrs attribute; use .grib2io.update_attrs({kwargs}) on a DataArray instead."
1366        )

Raises an error because Datasets don't have a .attrs attribute.

Parameters
  • **kwargs: Attributes requested for update; they are only echoed in the error message.
Raises
  • ValueError: Always.
def subset(self, lats, lons) -> xarray.core.dataset.Dataset:
1368    def subset(self, lats, lons) -> xr.Dataset:
1369        """
1370        Subset the DataSet to a region defined by latitudes and longitudes.
1371
1372        Parameters
1373        ----------
1374        lats
1375            Latitude bounds of the region.
1376        lons
1377            Longitude bounds of the region.
1378
1379        Returns
1380        -------
1381        subset
1382            DataSet subset to the region.
1383        """
1384        ds = self._obj
1385
1386        newds = xr.Dataset()
1387        for shortName in ds:
1388            newds[shortName] = ds[shortName].grib2io.subset(lats, lons).copy()
1389
1390        return newds

Subset the DataSet to a region defined by latitudes and longitudes.

Parameters
  • lats: Latitude bounds of the region.
  • lons: Longitude bounds of the region.
Returns
  • subset: DataSet subset to the region.
@xr.register_dataarray_accessor('grib2io')
class Grib2ioDataArray:
@xr.register_dataarray_accessor("grib2io")
class Grib2ioDataArray:
    """
    ``.grib2io`` accessor for DataArrays decoded by the grib2io engine.

    The methods rely on two things the grib2io engine attaches to each
    DataArray: the ``GRIB2IO_section0`` .. ``GRIB2IO_section5`` attributes
    and the ``latitude``/``longitude`` coordinates.
    """

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def griddef(self):
        """Return the Grib2GridDef built from the ``GRIB2IO_section3`` attribute."""
        return Grib2GridDef.from_section3(self._obj.attrs['GRIB2IO_section3'])

    def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.DataArray:
        """
        Perform grid spatial interpolation.

        Uses the [NCEPLIBS-ip library](https://github.com/NOAA-EMC/NCEPLIBS-ip).

        Parameters
        ----------
        method
            Interpolate method to use. This can either be an integer or string
            using the following mapping:

            | Interpolate Scheme | Integer Value |
            | :---:              | :---:         |
            | 'bilinear'         | 0             |
            | 'bicubic'          | 1             |
            | 'neighbor'         | 2             |
            | 'budget'           | 3             |
            | 'spectral'         | 4             |
            | 'neighbor-budget'  | 6             |
        grid_def_out
            Grib2GridDef object of the output grid.
        method_options : list of ints, optional
            Interpolation options. See the NCEPLIBS-ip documentation for
            more information on how these are used.
        num_threads : int, optional
            Number of OpenMP threads to use for interpolation, for both
            in-memory and dask-backed data. The default value is 1. If
            grib2io_interp was not built with OpenMP, then this keyword
            argument and value will have no impact.

        Returns
        -------
        interp
            DataArray interpolated to new grid definition.  The attribute
            GRIB2IO_section3 is replaced with the section3 array from the new
            grid definition.
        """
        da = self._obj
        # ensure that y, x are rightmost dims; they should be if opening with
        # grib2io engine

        # gdtn and gdt is not the entirety of the new s3
        npoints = grid_def_out.npoints
        s3_new = np.array([0, npoints, 0, 0, grid_def_out.gdtn] + list(grid_def_out.gdt))

        # make new lat lons
        lats, lons = Grib2Message(section3=s3_new, pdtn=0, drtn=0).grid()
        latitude = xr.DataArray(lats, dims=['y', 'x'])
        longitude = xr.DataArray(lons, dims=['y', 'x'])

        # create new coords
        new_coords = dict(da.coords)
        del new_coords['latitude']
        del new_coords['longitude']
        new_coords['longitude'] = longitude
        new_coords['latitude'] = latitude

        # make grid def in from section3 on da.attrs
        grid_def_in = self.griddef()

        interp_kwargs = dict(method=method, grid_def_in=grid_def_in,
                             grid_def_out=grid_def_out,
                             method_options=method_options, num_threads=num_threads)
        if da.chunks is None:
            data = interp_nd(da.data, **interp_kwargs)
        else:
            # dask-backed: interpolate block by block; the (y, x) chunk of each
            # block becomes the output grid shape (assumes y and x are each a
            # single chunk)
            data = da.data.map_blocks(interp_nd, chunks=da.chunks[:-2] + latitude.shape,
                                      dtype=da.dtype, **interp_kwargs)

        new_da = xr.DataArray(data, dims=da.dims, coords=new_coords, attrs=da.attrs)

        new_da.attrs['GRIB2IO_section3'] = s3_new
        new_da.name = da.name
        return new_da

    def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.DataArray:
        """
        Perform spatial interpolation to station points.

        Parameters
        ----------
        method
            Interpolate method to use. This can either be an integer or string
            using the following mapping:

            | Interpolate Scheme | Integer Value |
            | :---:              | :---:         |
            | 'bilinear'         | 0             |
            | 'bicubic'          | 1             |
            | 'neighbor'         | 2             |
            | 'budget'           | 3             |
            | 'spectral'         | 4             |
            | 'neighbor-budget'  | 6             |

        calls
            Station calls used for labeling new station index coordinate
        lats
            Latitudes of the station points.
        lons
            Longitudes of the station points.
        method_options : list of ints, optional
            Interpolation options. See the NCEPLIBS-ip documentation for
            more information on how these are used.
        num_threads : int, optional
            Number of OpenMP threads to use for interpolation, for both
            in-memory and dask-backed data. The default value is 1.

        Returns
        -------
        interp_to_stations
            DataArray interpolated to lat and lon locations and labeled with
            dimension and coordinate 'station'. (..., y, x) -> (..., station)
        """
        da = self._obj
        # TODO ensure that y, x are rightmost dims; they should be if opening
        # with grib2io engine

        calls = np.asarray(calls)
        lats = np.asarray(lats)
        lons = np.asarray(lons)
        latitude = xr.DataArray(lats, dims=['station'])
        longitude = xr.DataArray(lons, dims=['station'])

        # create new coords
        new_coords = dict(da.coords)
        del new_coords['latitude']
        del new_coords['longitude']
        new_coords['longitude'] = longitude
        new_coords['latitude'] = latitude
        new_coords['station'] = calls

        # (..., y, x) -> (..., station)
        new_dims = da.dims[:-2] + ('station',)

        # make grid def in from section3 on da attrs
        grid_def_in = self.griddef()

        interp_kwargs = dict(method=method, grid_def_in=grid_def_in, lats=lats,
                             lons=lons, method_options=method_options,
                             num_threads=num_threads)
        if da.chunks is None:
            data = interp_nd_stations(da.data, **interp_kwargs)
        else:
            # dask-backed: the x axis is dropped and the y axis becomes the
            # station axis of length len(lats) (assumes y and x are each a
            # single chunk)
            data = da.data.map_blocks(interp_nd_stations, drop_axis=-1,
                                      chunks=da.chunks[:-2] + latitude.shape,
                                      dtype=da.dtype, **interp_kwargs)

        new_da = xr.DataArray(data, dims=new_dims, coords=new_coords, attrs=da.attrs)

        new_da.name = da.name
        return new_da

    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
        """
        Write a DataArray to a grib2 file.

        One GRIB2 message is written for each combination of values along the
        dimensions listed in ``AVAILABLE_NON_GEO_DIMS``.

        Parameters
        ----------
        filename
            Name of the grib2 file to write to.
        mode: {"x", "w", "a"}, optional, default="x"
            Persistence mode used for the first message; later messages are
            appended.

            +------+-----------------------------------+
            | mode | Description                       |
            +======+===================================+
            | x    | create (fail if exists)           |
            +------+-----------------------------------+
            | w    | create (overwrite if exists)      |
            +------+-----------------------------------+
            | a    | append (create if does not exist) |
            +------+-----------------------------------+

        Raises
        ------
        ValueError
            If a dimension coordinate has duplicate values.
        """
        da = self._obj.copy(deep=True)

        coords_keys = sorted(da.coords.keys())
        coords_keys = [k for k in coords_keys if k in AVAILABLE_NON_GEO_COORDS]

        # If there are dimension coordinates, the DataArray is a hypercube of
        # grib2 messages.

        # Create `indexes` which is a list of lists of dictionaries for all
        # dimension coordinates. Each dictionary key is the dimension
        # coordinate name and the value is a list of the dimension coordinate
        # values.  This allows for easy iteration over all possible grib2
        # messages in the DataArray by using itertools.product.
        #
        # For example:
        # indexes = [
        #     [
        #         {"leadTime": 9},
        #         {"leadTime": 12},
        #     ],
        #     [
        #         {"valueOfFirstFixedSurface": 900},
        #         {"valueOfFirstFixedSurface": 925},
        #         {"valueOfFirstFixedSurface": 950},
        #     ],
        # ]

        # assign loc indexes to dimensions without indexes for uniform selection by name
        loc_indexes = list()
        for dim in da.dims:
            if dim not in da.indexes:
                da = da.assign_coords({dim: range(da[dim].size)})
                loc_indexes.append(dim)

        indexes = []
        for index in [i for i in AVAILABLE_NON_GEO_DIMS if i in da.dims]:
            values = da.coords[index].values
            if len(values) != len(set(values)):
                raise ValueError(
                    f"Dimension coordinate '{index}' has duplicate values, but to_grib2 requires unique values to find each GRIB2 message in the DataArray."
                )
            listeach = [{index: value} for value in sorted(values)]
            indexes.append(listeach)

        # If `indexes` is [], then the DataArray is a single grib2 message and
        # itertools.product(*indexes) will run once with `selectors = ()`.
        for selectors in itertools.product(*indexes):
            # Need to find the correct data in the DataArray based on the
            # dimension coordinates.
            filters = {k: v for d in selectors for k, v in d.items()}

            # If `filters` is {}, then the DataArray is a single grib2 message
            # and da.sel(indexers={}) returns the DataArray.
            selected = da.sel(indexers=filters)

            newmsg = Grib2Message(
                selected.attrs["GRIB2IO_section0"],
                selected.attrs["GRIB2IO_section1"],
                selected.attrs["GRIB2IO_section2"],
                selected.attrs["GRIB2IO_section3"],
                selected.attrs["GRIB2IO_section4"],
                selected.attrs["GRIB2IO_section5"],
            )
            newmsg.data = np.array(selected.data)

            # For dimension coordinates, set the grib2 message metadata to the
            # dimension coordinate value (positional loc indexes are skipped).
            for index, value in filters.items():
                if index not in loc_indexes:
                    setattr(newmsg, index, value)

            # For non-dimension coordinates, set the grib2 message metadata to
            # the DataArray coordinate value.
            for index in [i for i in coords_keys if i not in da.dims]:
                setattr(newmsg, index, selected.coords[index].values)

            # Set section 5 attributes to the da.encoding dictionary.
            for key, value in selected.encoding.items():
                if key in ["dtype", "chunks", "original_shape"]:
                    continue
                setattr(newmsg, key, value)

            # write the message to file
            with grib2io.open(filename, mode=mode) as f:
                f.write(newmsg)
            mode = "a"

    def update_attrs(self, **kwargs):
        """
        Update many of the attributes of the DataArray.

        Parameters
        ----------
        **kwargs
            Attributes to update.  This can include many of the GRIB2IO message
            attributes that you can find when you print a GRIB2IO message. For
            conflicting updates, the last keyword will be used.

            +-----------------------+------------------------------------------+
            | kwargs                | Description                              |
            +=======================+==========================================+
            | shortName="VTMP"      | Set shortName to "VTMP", along with      |
            |                       | appropriate discipline,                  |
            |                       | parameterCategory, parameterNumber,      |
            |                       | fullName and units.                      |
            +-----------------------+------------------------------------------+
            | discipline=0,         | Set shortName, discipline,               |
            | parameterCategory=0,  | parameterCategory, parameterNumber,      |
            | parameterNumber=1     | fullName and units appropriate for       |
            |                       | "Virtual Temperature".                   |
            +-----------------------+------------------------------------------+
            | discipline=0,         | Conflicting keywords but                 |
            | parameterCategory=0,  | 'shortName="TMP"' wins.  Set shortName,  |
            | parameterNumber=1,    | discipline, parameterCategory,           |
            | shortName="TMP"       | parameterNumber, fullName and units      |
            |                       | appropriate for "Temperature".           |
            +-----------------------+------------------------------------------+

        Returns
        -------
        DataArray
            Deep copy of the DataArray with updated attributes.

        Raises
        ------
        ValueError
            If a grid, product or data representation template number is
            requested to change.
        """
        da = self._obj.copy(deep=True)

        newmsg = Grib2Message(
            da.attrs["GRIB2IO_section0"],
            da.attrs["GRIB2IO_section1"],
            da.attrs["GRIB2IO_section2"],
            da.attrs["GRIB2IO_section3"],
            da.attrs["GRIB2IO_section4"],
            da.attrs["GRIB2IO_section5"],
        )

        coords_keys = [
            k
            for k in da.coords.keys()
            if k in AVAILABLE_NON_GEO_COORDS
        ]

        for grib2_name, value in kwargs.items():
            if grib2_name == "gridDefinitionTemplateNumber":
                raise ValueError(
                    "The gridDefinitionTemplateNumber attribute cannot be updated.  The best way to change to a different grid is to interpolate the data to a new grid using the grib2io interpolate functions."
                )
            if grib2_name == "productDefinitionTemplateNumber":
                raise ValueError(
                    "The productDefinitionTemplateNumber attribute cannot be updated."
                )
            if grib2_name == "dataRepresentationTemplateNumber":
                raise ValueError(
                    "The dataRepresentationTemplateNumber attribute cannot be updated."
                )
            if grib2_name in coords_keys:
                warnings.warn(
                    f"Skipping attribute '{grib2_name}' because it is a coordinate. Use da.assign_coords() to change coordinate values."
                )
                continue
            if hasattr(newmsg, grib2_name):
                setattr(newmsg, grib2_name, value)
            else:
                warnings.warn(
                    f"Skipping attribute '{grib2_name}' because it is not a valid GRIB2 attribute for this message and cannot be updated."
                )
                continue

        da.attrs["GRIB2IO_section0"] = newmsg.section0
        da.attrs["GRIB2IO_section1"] = newmsg.section1
        da.attrs["GRIB2IO_section2"] = newmsg.section2 or []
        da.attrs["GRIB2IO_section3"] = newmsg.section3
        da.attrs["GRIB2IO_section4"] = newmsg.section4
        da.attrs["GRIB2IO_section5"] = newmsg.section5
        da.attrs["fullName"] = newmsg.fullName
        da.attrs["shortName"] = newmsg.shortName
        da.attrs["units"] = newmsg.units

        return da

    def subset(self, lats, lons) -> xr.DataArray:
        """
        Subset the DataArray to a region defined by latitudes and longitudes.

        Parameters
        ----------
        lats
            Latitude bounds of the region.
        lons
            Longitude bounds of the region.

        Returns
        -------
        subset
            DataArray subset to the region.
        """
        da = self._obj.copy(deep=True)

        newmsg = Grib2Message(
            da.attrs["GRIB2IO_section0"],
            da.attrs["GRIB2IO_section1"],
            da.attrs["GRIB2IO_section2"],
            da.attrs["GRIB2IO_section3"],
            da.attrs["GRIB2IO_section4"],
            da.attrs["GRIB2IO_section5"],
        )

        # placeholder data so Grib2Message.subset can compute the new grid
        newmsg.data = np.zeros((newmsg.ny, newmsg.nx), dtype=np.float32)

        newmsg = newmsg.subset(lats, lons)

        da.attrs["GRIB2IO_section3"] = newmsg.section3

        mask_lat = (da.latitude >= newmsg.latitudeLastGridpoint) & (
            da.latitude <= newmsg.latitudeFirstGridpoint
        )
        mask_lon = (da.longitude >= newmsg.longitudeFirstGridpoint) & (
            da.longitude <= newmsg.longitudeLastGridpoint
        )

        del newmsg

        return da.where((mask_lon & mask_lat).compute(), drop=True)
Grib2ioDataArray(xarray_obj)
1396    def __init__(self, xarray_obj):
1397        self._obj = xarray_obj
def griddef(self):
1399    def griddef(self):
1400        return Grib2GridDef.from_section3(self._obj.attrs['GRIB2IO_section3'])
def interp( self, method, grid_def_out, method_options=None, num_threads=1) -> xarray.core.dataarray.DataArray:
1402    def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.DataArray:
1403        """
1404        Perform grid spatial interpolation.
1405
1406        Uses the [NCEPLIBS-ip library](https://github.com/NOAA-EMC/NCEPLIBS-ip).
1407
1408        Parameters
1409        ----------
1410        method
1411            Interpolate method to use. This can either be an integer or string
1412            using the following mapping:
1413
1414            | Interpolate Scheme | Integer Value |
1415            | :---:              | :---:         |
1416            | 'bilinear'         | 0             |
1417            | 'bicubic'          | 1             |
1418            | 'neighbor'         | 2             |
1419            | 'budget'           | 3             |
1420            | 'spectral'         | 4             |
1421            | 'neighbor-budget'  | 6             |
1422        grid_def_out
1423            Grib2GridDef object of the output grid.
1424        method_options : list of ints, optional
1425            Interpolation options. See the NCEPLIBS-ip documentation for
1426            more information on how these are used.
1427        num_threads : int, optional
1428            Number of OpenMP threads to use for interpolation. The default
1429            value is 1. If grib2io_interp was not built with OpenMP, then
1430            this keyword argument and value will have no impact.
1431
1432        Returns
1433        -------
1434        interp
1435            DataSet interpolated to new grid definition.  The attribute
1436            GRIB2IO_section3 is replaced with the section3 array from the new
1437            grid definition.
1438        """
1439        da = self._obj
1440        # ensure that y, x are rightmost dims; they should be if opening with
1441        # grib2io engine
1442
1443        # gdtn and gdt is not the entirety of the new s3
1444        npoints = grid_def_out.npoints
1445        s3_new = np.array([0, npoints, 0, 0, grid_def_out.gdtn] + list(grid_def_out.gdt))
1446
1447        # make new lat lons
1448        lats, lons = Grib2Message(section3=s3_new, pdtn=0, drtn=0).grid()
1449        latitude = xr.DataArray(lats, dims=['y', 'x'])
1450        longitude = xr.DataArray(lons, dims=['y', 'x'])
1451
1452        # create new coords
1453        new_coords = dict(da.coords)
1454        del new_coords['latitude']
1455        del new_coords['longitude']
1456        new_coords['longitude'] = longitude
1457        new_coords['latitude'] = latitude
1458
1459        # make grid def in from section3 on da.attrs
1460        grid_def_in = self.griddef()
1461
1462        if da.chunks is None:
1463            data = interp_nd(da.data, method=method, grid_def_in=grid_def_in,
1464                             grid_def_out=grid_def_out,
1465                             method_options=method_options, num_threads=num_threads)
1466        else:
1467            import dask
1468            front_shape = da.shape[:-2]
1469            data = da.data.map_blocks(interp_nd, method=method, grid_def_in=grid_def_in,
1470                                      grid_def_out=grid_def_out, method_options=method_options,
1471                                      chunks=da.chunks[:-2]+latitude.shape, dtype=da.dtype)
1472
1473        new_da = xr.DataArray(data, dims=da.dims, coords=new_coords, attrs=da.attrs)
1474
1475        new_da.attrs['GRIB2IO_section3'] = s3_new
1476        new_da.name = da.name
1477        return new_da

Perform grid spatial interpolation.

Uses the NCEPLIBS-ip library.

Parameters
  • method: Interpolation method to use, given either as an integer or as a string according to the following mapping:
      Interpolation Scheme   Integer Value
      'bilinear'             0
      'bicubic'              1
      'neighbor'             2
      'budget'               3
      'spectral'             4
      'neighbor-budget'      6

  • grid_def_out: Grib2GridDef object of the output grid.
  • method_options (list of ints, optional): Interpolation options. See the NCEPLIBS-ip documentation for more information on how these are used.
  • num_threads (int, optional): Number of OpenMP threads to use for interpolation. The default value is 1. If grib2io_interp was not built with OpenMP, then this keyword argument and value will have no impact.
Returns
  • interp: DataArray interpolated to the new grid definition. The GRIB2IO_section3 attribute is replaced with the section3 array of the new grid definition.
def interp_to_stations( self, method, calls, lats, lons, method_options=None, num_threads=1) -> xarray.core.dataarray.DataArray:
1479    def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.DataArray:
1480        """
1481        Perform spatial interpolation to station points.
1482
1483        Parameters
1484        ----------
1485        method
1486            Interpolate method to use. This can either be an integer or string
1487            using the following mapping:
1488
1489            | Interpolate Scheme | Integer Value |
1490            | :---:              | :---:         |
1491            | 'bilinear'         | 0             |
1492            | 'bicubic'          | 1             |
1493            | 'neighbor'         | 2             |
1494            | 'budget'           | 3             |
1495            | 'spectral'         | 4             |
1496            | 'neighbor-budget'  | 6             |
1497
1498        calls
1499            Station calls used for labeling new station index coordinate
1500        lats
1501            Latitudes of the station points.
1502        lons
1503            Longitudes of the station points.
1504
1505        Returns
1506        -------
1507        interp_to_stations
1508            DataArray interpolated to lat and lon locations and labeled with
1509            dimension and coordinate 'station'. (..., y, x) -> (..., station)
1510        """
1511        da = self._obj
1512        # TODO ensure that y, x are rightmost dims; they should be if opening
1513        # with grib2io engine
1514
1515        calls = np.asarray(calls)
1516        lats = np.asarray(lats)
1517        lons = np.asarray(lons)
1518        latitude = xr.DataArray(lats, dims=['station'])
1519        longitude = xr.DataArray(lons, dims=['station'])
1520
1521        # create new coords
1522        new_coords = dict(da.coords)
1523        del new_coords['latitude']
1524        del new_coords['longitude']
1525        new_coords['longitude'] = longitude
1526        new_coords['latitude'] = latitude
1527        new_coords['station'] = calls
1528
1529        new_dims = da.dims[:-2] + ('station',)
1530
1531        # make grid def in from section3 on da attrs
1532        grid_def_in = self.griddef()
1533
1534        if da.chunks is None:
1535            data = interp_nd_stations(da.data, method=method, grid_def_in=grid_def_in, lats=lats,
1536                                      lons=lons, method_options=method_options, num_threads=num_threads)
1537        else:
1538            import dask
1539            front_shape = da.shape[:-1]
1540            data = da.data.map_blocks(interp_nd_stations, method=method, grid_def_in=grid_def_in,
1541                                      lats=lats, lons=lons, method_options=method_options,
1542                                      drop_axis=-1, chunks=da.chunks[:-2]+latitude.shape,
1543                                      dtype=da.dtype)
1544
1545        new_da = xr.DataArray(data, dims=new_dims, coords=new_coords, attrs=da.attrs)
1546
1547        new_da.name = da.name
1548        return new_da

Perform spatial interpolation to station points.

Parameters
  • method: Interpolation method to use, given either as an integer or as a string according to the following mapping:
      Interpolation Scheme   Integer Value
      'bilinear'             0
      'bicubic'              1
      'neighbor'             2
      'budget'               3
      'spectral'             4
      'neighbor-budget'      6

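  • calls: Station identifiers (calls) used to label the new 'station' index coordinate.
  • lats: Latitudes of the station points.
  • lons: Longitudes of the station points.
  • method_options (list of ints, optional): Interpolation options. See the NCEPLIBS-ip documentation for more information on how these are used.
  • num_threads (int, optional): Number of OpenMP threads to use for interpolation. The default value is 1. If grib2io_interp was not built with OpenMP, this argument has no effect.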
Returns
  • interp_to_stations: DataArray interpolated to the given latitude/longitude points and labeled with the dimension and coordinate 'station'. The trailing (y, x) dimensions are replaced by 'station', i.e. (..., y, x) -> (..., station).
def to_grib2(self, filename, mode: Literal['x', 'w', 'a'] = 'x'):
1550    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
1551        """
1552        Write a DataArray to a grib2 file.
1553
1554        Parameters
1555        ----------
1556        filename
1557            Name of the grib2 file to write to.
1558        mode: {"x", "w", "a"}, optional, default="x"
1559            Persistence mode
1560
1561            +------+-----------------------------------+
1562            | mode | Description                       |
1563            +======+===================================+
1564            | x    | create (fail if exists)           |
1565            +------+-----------------------------------+
1566            | w    | create (overwrite if exists)      |
1567            +------+-----------------------------------+
1568            | a    | append (create if does not exist) |
1569            +------+-----------------------------------+
1570
1571        """
1572        da = self._obj.copy(deep=True)
1573
1574        coords_keys = sorted(da.coords.keys())
1575        coords_keys = [k for k in coords_keys if k in AVAILABLE_NON_GEO_COORDS]
1576
1577        # If there are dimension coordinates, the DataArray is a hypercube of
1578        # grib2 messages.
1579
1580        # Create `indexes` which is a list of lists of dictionaries for all
1581        # dimension coordinates. Each dictionary key is the dimension
1582        # coordinate name and the value is a list of the dimension coordinate
1583        # values.  This allows for easy iteration over all possible grib2
1584        # messages in the DataArray by using itertools.product.
1585        #
1586        # For example:
1587        # indexes = [
1588        #     [
1589        #         {"leadTime": 9},
1590        #         {"leadTime": 12},
1591        #     ],
1592        #     [
1593        #         {"valueOfFirstFixedSurface": 900},
1594        #         {"valueOfFirstFixedSurface": 925},
1595        #         {"valueOfFirstFixedSurface": 950},
1596        #     ],
1597        # ]
1598
1599        # assign loc indexes to dimensions without indexes for uniform selection by name
1600        loc_indexes = list()
1601        for dim in da.dims:
1602            if dim not in da.indexes:
1603                da = da.assign_coords({dim: range(da[dim].size)})
1604                loc_indexes.append(dim)
1605
1606        indexes = []
1607        for index in [i for i in AVAILABLE_NON_GEO_DIMS if i in da.dims]:
1608            values = da.coords[index].values
1609            if len(values) != len(set(values)):
1610                raise ValueError(
1611                    f"Dimension coordinate '{index}' has duplicate values, but to_grib2 requires unique values to find each GRIB2 message in the DataArray."
1612                )
1613            listeach = [{index: value} for value in sorted(values)]
1614            indexes.append(listeach)
1615
1616        # If `dim_coords` is [], then the DataArray is a single grib2 message and
1617        # itertools.product(*dim_coords) will run once with `selectors = ()`.
1618        for selectors in itertools.product(*indexes):
1619            # Need to find the correct data in the DataArray based on the
1620            # dimension coordinates.
1621            filters = {k: v for d in selectors for k, v in d.items()}
1622
1623            # If `filters` is {}, then the DataArray is a single grib2 message
1624            # and da.sel(indexers={}) returns the DataArray.
1625            selected = da.sel(indexers=filters)
1626
1627            newmsg = Grib2Message(
1628                selected.attrs["GRIB2IO_section0"],
1629                selected.attrs["GRIB2IO_section1"],
1630                selected.attrs["GRIB2IO_section2"],
1631                selected.attrs["GRIB2IO_section3"],
1632                selected.attrs["GRIB2IO_section4"],
1633                selected.attrs["GRIB2IO_section5"],
1634            )
1635            newmsg.data = np.array(selected.data)
1636
1637            # For dimension coordinates, set the grib2 message metadata to the
1638            # dimension coordinate value.
1639            for index, value in filters.items():
1640                if index not in loc_indexes:
1641                    setattr(newmsg, index, value)
1642
1643            # For non-dimension coordinates, set the grib2 message metadata to
1644            # the DataArray coordinate value.
1645            for index in [i for i in coords_keys if i not in da.dims]:
1646                setattr(newmsg, index, selected.coords[index].values)
1647
1648            # Set section 5 attributes to the da.encoding dictionary.
1649            for key, value in selected.encoding.items():
1650                if key in ["dtype", "chunks", "original_shape"]:
1651                    continue
1652                setattr(newmsg, key, value)
1653
1654            # write the message to file
1655            with grib2io.open(filename, mode=mode) as f:
1656                f.write(newmsg)
1657            mode = "a"

Write a DataArray to a grib2 file.

Parameters
  • filename: Name of the grib2 file to write to.
  • mode ({"x", "w", "a"}, optional, default="x"): Persistence mode

      mode   Description
      x      create (fail if exists)
      w      create (overwrite if exists)
      a      append (create if does not exist)

def update_attrs(self, **kwargs):
1659    def update_attrs(self, **kwargs):
1660        """
1661        Update many of the attributes of the DataArray.
1662
1663        Parameters
1664        ----------
1665        **kwargs
1666            Attributes to update.  This can include many of the GRIB2IO message
1667            attributes that you can find when you print a GRIB2IO message. For
1668            conflicting updates, the last keyword will be used.
1669
1670            +-----------------------+------------------------------------------+
1671            | kwargs                | Description                              |
1672            +=======================+==========================================+
1673            | shortName="VTMP"      | Set shortName to "VTMP", along with      |
1674            |                       | appropriate discipline,                  |
1675            |                       | parameterCategory, parameterNumber,      |
1676            |                       | fullName and units.                      |
1677            +-----------------------+------------------------------------------+
1678            | discipline=0,         | Set shortName, discipline,               |
1679            | parameterCategory=0,  | parameterCategory, parameterNumber,      |
1680            | parameterNumber=1     | fullName and units appropriate for       |
1681            |                       | "Virtual Temperature".                   |
1682            +-----------------------+------------------------------------------+
1683            | discipline=0,         | Conflicting keywords but                 |
1684            | parameterCategory=0,  | 'shortName="TMP"' wins.  Set shortName,  |
1685            | parameterNumber=1,    | discipline, parameterCategory,           |
1686            | shortName="TMP"       | parameterNumber, fullName and units      |
1687            |                       | appropriate for "Temperature".           |
1688            +-----------------------+------------------------------------------+
1689
1690        Returns
1691        -------
1692        DataArray
1693            DataArray with updated attributes.
1694        """
1695        da = self._obj.copy(deep=True)
1696
1697        newmsg = Grib2Message(
1698            da.attrs["GRIB2IO_section0"],
1699            da.attrs["GRIB2IO_section1"],
1700            da.attrs["GRIB2IO_section2"],
1701            da.attrs["GRIB2IO_section3"],
1702            da.attrs["GRIB2IO_section4"],
1703            da.attrs["GRIB2IO_section5"],
1704        )
1705
1706        coords_keys = [
1707            k
1708            for k in da.coords.keys()
1709            if k in AVAILABLE_NON_GEO_COORDS
1710        ]
1711
1712        for grib2_name, value in kwargs.items():
1713            if grib2_name == "gridDefinitionTemplateNumber":
1714                raise ValueError(
1715                    "The gridDefinitionTemplateNumber attribute cannot be updated.  The best way to change to a different grid is to interpolate the data to a new grid using the grib2io interpolate functions."
1716                )
1717            if grib2_name == "productDefinitionTemplateNumber":
1718                raise ValueError(
1719                    "The productDefinitionTemplateNumber attribute cannot be updated."
1720                )
1721            if grib2_name == "dataRepresentationTemplateNumber":
1722                raise ValueError(
1723                    "The dataRepresentationTemplateNumber attribute cannot be updated."
1724                )
1725            if grib2_name in coords_keys:
1726                warnings.warn(
1727                    f"Skipping attribute '{grib2_name}' because it is a coordinate. Use da.assign_coords() to change coordinate values."
1728                )
1729                continue
1730            if hasattr(newmsg, grib2_name):
1731                setattr(newmsg, grib2_name, value)
1732            else:
1733                warnings.warn(
1734                    f"Skipping attribute '{grib2_name}' because it is not a valid GRIB2 attribute for this message and cannot be updated."
1735                )
1736                continue
1737
1738        da.attrs["GRIB2IO_section0"] = newmsg.section0
1739        da.attrs["GRIB2IO_section1"] = newmsg.section1
1740        da.attrs["GRIB2IO_section2"] = newmsg.section2 or []
1741        da.attrs["GRIB2IO_section3"] = newmsg.section3
1742        da.attrs["GRIB2IO_section4"] = newmsg.section4
1743        da.attrs["GRIB2IO_section5"] = newmsg.section5
1744        da.attrs["fullName"] = newmsg.fullName
1745        da.attrs["shortName"] = newmsg.shortName
1746        da.attrs["units"] = newmsg.units
1747
1748        return da

Update many of the attributes of the DataArray.

Parameters
  • **kwargs: Attributes to update. These can include many of the GRIB2IO message attributes shown when you print a GRIB2IO message. For conflicting updates, the last keyword will be used.

    Examples:
      - shortName="VTMP": sets shortName to "VTMP", along with the appropriate discipline, parameterCategory, parameterNumber, fullName and units.
      - discipline=0, parameterCategory=0, parameterNumber=1: sets shortName, discipline, parameterCategory, parameterNumber, fullName and units appropriate for "Virtual Temperature".
      - discipline=0, parameterCategory=0, parameterNumber=1, shortName="TMP": conflicting keywords, but shortName="TMP" wins; sets shortName, discipline, parameterCategory, parameterNumber, fullName and units appropriate for "Temperature".

Returns
  • DataArray: DataArray with updated attributes.
def subset(self, lats, lons) -> xarray.core.dataarray.DataArray:
1750    def subset(self, lats, lons) -> xr.DataArray:
1751        """
1752        Subset the DataArray to a region defined by latitudes and longitudes.
1753
1754        Parameters
1755        ----------
1756        lats
1757            Latitude bounds of the region.
1758        lons
1759            Longitude bounds of the region.
1760
1761        Returns
1762        -------
1763        subset
1764            DataArray subset to the region.
1765        """
1766        da = self._obj.copy(deep=True)
1767
1768        newmsg = Grib2Message(
1769            da.attrs["GRIB2IO_section0"],
1770            da.attrs["GRIB2IO_section1"],
1771            da.attrs["GRIB2IO_section2"],
1772            da.attrs["GRIB2IO_section3"],
1773            da.attrs["GRIB2IO_section4"],
1774            da.attrs["GRIB2IO_section5"],
1775        )
1776
1777        newmsg.data = np.zeros((newmsg.ny, newmsg.nx), dtype=np.float32)
1778
1779        newmsg = newmsg.subset(lats, lons)
1780
1781        da.attrs["GRIB2IO_section3"] = newmsg.section3
1782
1783        mask_lat = (da.latitude >= newmsg.latitudeLastGridpoint) & (
1784            da.latitude <= newmsg.latitudeFirstGridpoint
1785        )
1786        mask_lon = (da.longitude >= newmsg.longitudeFirstGridpoint) & (
1787            da.longitude <= newmsg.longitudeLastGridpoint
1788        )
1789
1790        del newmsg
1791
1792        return da.where((mask_lon & mask_lat).compute(), drop=True)

Subset the DataArray to a region defined by latitudes and longitudes.

Parameters
  • lats: Latitude bounds of the region.
  • lons: Longitude bounds of the region.
Returns
  • subset: DataArray subset to the region.
def build_datatree_from_grib(filename, file_index, filters=None, stack_vertical=False):
1795def build_datatree_from_grib(filename, file_index, filters=None, stack_vertical=False):
1796    """
1797    Build a DataTree from GRIB2 messages.
1798
1799    Parameters
1800    ----------
1801    filename : str
1802        Path to the GRIB2 file.
1803    file_index : pandas.DataFrame
1804        DataFrame of GRIB2 message index.
1805    filters : dict, optional
1806        Filter criteria for GRIB2 messages.
1807    stack_vertical : bool, optional
1808        If True, vertical levels will be stacked in a single dataset
1809        instead of being organized in separate tree nodes.
1810
1811    Returns
1812    -------
1813    xarray.DataTree
1814        A hierarchical DataTree representation of the GRIB2 data.
1815    """
1816    if filters is None:
1817        filters = {}
1818
1819    # Apply any filters from user
1820    for k, v in filters.items():
1821        if k not in file_index.columns:
1822            file_index = file_index.copy()
1823            file_index[k] = file_index.msg.apply(lambda msg: getattr(msg, k, None))
1824        file_index = filter_index(file_index, k, v)
1825
1826    # Make a copy to avoid the SettingWithCopyWarning
1827    file_index = file_index.copy()
1828
1829    # Extract metadata needed for tree organization
1830    # Use a safer approach to handle missing attributes
1831    def safe_getattr(obj, name):
1832        try:
1833            attr = getattr(obj, name)
1834            # Need to test if the attribute is Grib2Metadata. If so,
1835            # then get the value attribute.
1836            if isinstance(attr, grib2io.templates.Grib2Metadata):
1837                attr = attr.value
1838            return attr
1839        except (AttributeError, KeyError):
1840            return None
1841
1842    for attr in _TREE_HIERARCHY_LEVELS:
1843        if (attr not in file_index.columns) and (attr != 'valueOfFirstFixedSurface'):
1844            file_index[attr] = file_index.msg.apply(lambda msg: safe_getattr(msg, attr))
1845
1846    # Also extract shortName for variable naming
1847    if 'shortName' not in file_index.columns:
1848        file_index = file_index.assign(shortName=file_index.msg.apply(lambda msg: getattr(msg, 'shortName', None)))
1849        file_index = file_index.assign(nx=file_index.msg.apply(lambda msg: getattr(msg, 'nx', None)))
1850        file_index = file_index.assign(ny=file_index.msg.apply(lambda msg: getattr(msg, 'ny', None)))
1851
1852    # Create root DataTree
1853    root = xr.DataTree()
1854
1855    # Adjust hierarchy levels if we're stacking vertical levels
1856    hierarchy_levels = list(_TREE_HIERARCHY_LEVELS) # This makes a copy
1857    if stack_vertical and "valueOfFirstFixedSurface" in hierarchy_levels:
1858        hierarchy_levels.remove("valueOfFirstFixedSurface")
1859
1860    # First group by level type
1861    level_groups = {}
1862
1863    # Create a dictionary to group data by level type
1864    for level_type in file_index['typeOfFirstFixedSurface'].unique():
1865        if pd.notna(level_type):  # Skip None/NaN values
1866            level_info = _LEVEL_NAME_MAPPING.get(level_type, f"level_{level_type}")
1867            level_name = level_info[0]
1868            level_source = level_info[1]
1869            # Get all rows for this level type
1870            level_data = file_index[file_index['typeOfFirstFixedSurface'] == level_type]
1871            level_groups[level_type] = {'name': level_name, 'data': level_data}
1872
1873    # Process each level group
1874    for level_type, group_info in level_groups.items():
1875        level_name = group_info['name']
1876        level_df = group_info['data']
1877
1878        # Create a branch for this level type
1879        level_tree = xr.DataTree()
1880
1881        # Process this branch based on PDTN, perturbation number, etc.
1882        process_level_branch(level_tree, level_df, filename)
1883
1884        # Add this branch to the main tree
1885        root[level_name] = level_tree
1886
1887    return root

Build a DataTree from GRIB2 messages.

Parameters
  • filename (str): Path to the GRIB2 file.
  • file_index (pandas.DataFrame): DataFrame of GRIB2 message index.
  • filters (dict, optional): Filter criteria for GRIB2 messages.
  • stack_vertical (bool, optional): If True, vertical levels will be stacked in a single dataset instead of being organized in separate tree nodes.
Returns
  • xarray.DataTree: A hierarchical DataTree representation of the GRIB2 data.
def process_level_branch(level_tree, df, filename):
1890def process_level_branch(level_tree, df, filename):
1891    """
1892    Process a level type branch of the data tree, organizing by PDTN and other attributes.
1893
1894    Parameters
1895    ----------
1896    level_tree : xarray.DataTree
1897        The DataTree node for this level type
1898    df : pandas.DataFrame
1899        DataFrame of messages for this level type
1900    filename : str
1901        Path to the GRIB2 file
1902    """
1903    # Group by PDTN
1904    pdtn_groups = {}
1905
1906    # Group data by PDTN first
1907    for pdtn_value in df['productDefinitionTemplateNumber'].unique():
1908        if pd.notna(pdtn_value):
1909            pdtn_df = df[df['productDefinitionTemplateNumber'] == pdtn_value]
1910            pdtn_groups[pdtn_value] = pdtn_df
1911
1912    # If there's only one PDTN value, skip creating PDTN branch level
1913    if len(pdtn_groups) == 1:
1914        pdtn, pdtn_df = next(iter(pdtn_groups.items()))
1915
1916        pdtn_name = f"pdtn_{int(pdtn)}"
1917
1918        # Check if we need to further subdivide by perturbation number
1919        has_perturbations = ('perturbationNumber' in pdtn_df.columns and
1920                             len(pdtn_df['perturbationNumber'].dropna().unique()) > 1)
1921
1922        # Check if we need to further subdivide by probabilities unique for each variable.
1923        has_probabilities = ('typeOfProbability' in pdtn_df.columns and
1924                             len(pdtn_df['typeOfProbability'].dropna().unique()) > 1)
1925
1926        if has_perturbations:
1927            # Process perturbations directly on the level tree
1928            process_perturbation_groups(level_tree, pdtn_df, filename)
1929        elif has_probabilities:
1930            # Process probability groups
1931            process_probability_groups(level_tree, pdtn_df, filename)
1932        else:
1933            # Try to create dataset directly on level
1934            try:
1935                dss = create_datasets_from_df(pdtn_df, filename)
1936                if dss is not None:
1937                    dt = xr.DataTree()
1938                    if len(dss) == 1:
1939                        dt.ds = dss[0]
1940                    else:
1941                        for ds in dss:
1942                            varname = list(ds.data_vars)[0]
1943                            dt[f"var_{varname}"] = ds
1944                    level_tree[pdtn_name] = dt
1945            except Exception as e:
1946                print(f"Error creating dataset for level with pdtn {int(pdtn)}: {e}")
1947
1948                # Try to separate by variable name as a fallback
1949                try_process_by_variables(level_tree, pdtn_df, filename)
1950    else:
1951        # Multiple PDTN values, process each group with PDTN branch nodes
1952        for pdtn, pdtn_df in pdtn_groups.items():
1953            # Use a simple node name that's easy to use in code
1954            pdtn_name = f"pdtn_{int(pdtn)}"
1955
1956            # Check if we need to further subdivide by perturbation number
1957            has_perturbations = ('perturbationNumber' in pdtn_df.columns and
1958                                 len(pdtn_df['perturbationNumber'].dropna().unique()) > 1)
1959
1960            # Check if we need to further subdivide by probabilities unique for each variable.
1961            has_probabilities = ('typeOfProbability' in pdtn_df.columns and
1962                                 len(pdtn_df['typeOfProbability'].dropna().unique()) > 1)
1963
1964            if has_perturbations:
1965                # Create a branch for this PDTN
1966                pdtn_tree = xr.DataTree()
1967
1968                # Process perturbation groups
1969                process_perturbation_groups(pdtn_tree, pdtn_df, filename)
1970
1971                # Only add the PDTN branch if it has children
1972                if len(pdtn_tree.children) > 0 or pdtn_tree.ds is not None:
1973                    level_tree[pdtn_name] = pdtn_tree
1974            elif has_probabilities:
1975                # Create a branch for this PDTN
1976                pdtn_tree = xr.DataTree()
1977
1978                # Process probability groups
1979                process_probability_groups(pdtn_tree, pdtn_df, filename)
1980
1981                # Only add the PDTN branch if it has children
1982                if len(pdtn_tree.children) > 0 or pdtn_tree.ds is not None:
1983                    level_tree[pdtn_name] = pdtn_tree
1984            else:
1985                # Create a subtree for this PDTN
1986                pdtn_tree = xr.DataTree()
1987
1988                # Try to create dataset directly on level
1989                try:
1990                    dss = create_datasets_from_df(pdtn_df, filename)
1991                    if dss is not None:
1992                        if len(dss) == 1:
1993                            pdtn_tree.ds = dss[0]
1994                        else:
1995                            for ds in dss:
1996                                varname = list(ds.data_vars)[0]
1997                                pdtn_tree[f"var_{varname}"] = ds
1998                        level_tree[pdtn_name] = pdtn_tree
1999                except Exception as e:
2000                    print(f"Error creating dataset for level with pdtn {int(pdtn)}: {e}")
2001
2002                    # Try to separate by variable name as a fallback
2003                    try_process_by_variables(pdtn_tree, pdtn_df, filename)
2004                    level_tree[pdtn_name] = pdtn_tree

Process a level type branch of the data tree, organizing by PDTN and other attributes.

Parameters
  • level_tree (xarray.DataTree): The DataTree node for this level type
  • df (pandas.DataFrame): DataFrame of messages for this level type
  • filename (str): Path to the GRIB2 file
def process_probability_groups(target_tree, pdtn_df, filename):
    """
    Process probability groups and add them to the target tree.

    Messages are grouped by ``typeOfProbability``. Each group is decoded with
    ``create_datasets_from_df`` and attached to ``target_tree`` as a child
    node named ``prob_<typeOfProbability>``. A group that yields one Dataset
    stores it as that node's data. A group that yields several stores each
    one as a grandchild named ``var_<first data variable>``. Failures are
    printed and the remaining groups are still processed.

    Parameters
    ----------
    target_tree : xarray.DataTree
        The tree node to add probability groups to
    pdtn_df : pandas.DataFrame
        DataFrame of messages for a specific PDTN
    filename : str
        Path to the GRIB2 file

    Returns
    -------
    bool
        True if at least one probability group was added to ``target_tree``
    """
    success = False
    for prob_value in pdtn_df['typeOfProbability'].unique():
        # Messages without a probability type cannot be named; skip them.
        if pd.isna(prob_value):
            continue
        prob_name = f"prob_{int(prob_value)}"
        prob_df = pdtn_df[pdtn_df['typeOfProbability'] == prob_value]

        try:
            dss = create_datasets_from_df(prob_df, filename)
            # create_datasets_from_df returns None when nothing could be built.
            if not dss:
                print(f"Error creating dataset for type of probability {prob_name}: no datasets created")
                continue

            group_tree = xr.DataTree()
            if len(dss) == 1:
                group_tree.ds = dss[0]
            else:
                for ds in dss:
                    # data_vars is keyed by variable name, not position.
                    first_var = list(ds.data_vars)[0]
                    # Wrap explicitly so each Dataset becomes a child node.
                    group_tree[f"var_{first_var}"] = xr.DataTree(dataset=ds)
            # Attach only once the group subtree is complete.
            target_tree[prob_name] = group_tree
            success = True
        except Exception as e:
            # Log error but continue processing other groups
            print(f"Error creating dataset for type of probability {prob_name}: {e}")

    return success
def process_perturbation_groups(target_tree, pdtn_df, filename):
    """
    Process perturbation groups and add them to the target tree.

    Messages are grouped by ``perturbationNumber``. Each group is decoded
    with ``create_datasets_from_df`` and attached to ``target_tree`` as a
    child node named ``pert_<perturbationNumber>``. A group that yields one
    Dataset stores it as that node's data. A group that yields several stores
    each one as a grandchild named ``var_<first data variable>``. Failures
    are printed and the remaining groups are still processed.

    Parameters
    ----------
    target_tree : xarray.DataTree
        The tree node to add perturbation groups to
    pdtn_df : pandas.DataFrame
        DataFrame of messages for a specific PDTN
    filename : str
        Path to the GRIB2 file

    Returns
    -------
    bool
        True if at least one perturbation was successfully processed
    """
    success = False
    for pert_value in pdtn_df['perturbationNumber'].unique():
        # Messages without a perturbation number cannot be named; skip them.
        if pd.isna(pert_value):
            continue
        pert_name = f"pert_{int(pert_value)}"
        pert_df = pdtn_df[pdtn_df['perturbationNumber'] == pert_value]

        try:
            dss = create_datasets_from_df(pert_df, filename)
            # create_datasets_from_df returns None when nothing could be built.
            if not dss:
                print(f"Error creating dataset for perturbation {pert_name}: no datasets created")
                continue

            group_tree = xr.DataTree()
            if len(dss) == 1:
                group_tree.ds = dss[0]
            else:
                for ds in dss:
                    # data_vars is keyed by variable name, not position.
                    first_var = list(ds.data_vars)[0]
                    # Wrap explicitly so each Dataset becomes a child node.
                    group_tree[f"var_{first_var}"] = xr.DataTree(dataset=ds)
            # Attach only once the group subtree is complete.
            target_tree[pert_name] = group_tree
            success = True
        except Exception as e:
            # Log error but continue processing other groups
            print(f"Error creating dataset for perturbation {pert_name}: {e}")

    return success

Process perturbation groups and add them to the target tree.

Parameters
  • target_tree (xarray.DataTree): The tree node to add perturbation groups to
  • pdtn_df (pandas.DataFrame): DataFrame of messages for a specific PDTN
  • filename (str): Path to the GRIB2 file
Returns
  • bool: True if at least one perturbation was successfully processed
def try_process_by_variables(target_tree, df, filename):
    """
    Try to separate data by variable names and create datasets.

    Each distinct non-null ``shortName`` in ``df`` is decoded with
    ``create_datasets_from_df`` and attached to ``target_tree`` as a child
    node named ``var_<shortName>``. A failure for one variable is printed
    and the remaining variables are still processed.

    Parameters
    ----------
    target_tree : xarray.DataTree
        The tree node to add variable datasets to
    df : pandas.DataFrame
        DataFrame of messages
    filename : str
        Path to the GRIB2 file

    Returns
    -------
    bool
        True if at least one variable was successfully processed
    """
    success = False

    try:
        for var_name in df['shortName'].unique():
            if pd.isna(var_name):
                continue
            var_df = df[df['shortName'] == var_name]
            try:
                var_dss = create_datasets_from_df(var_df, filename)
                # create_datasets_from_df groups by shortName, so a single
                # variable should yield at most one Dataset (None on failure).
                if var_dss:
                    # Wrap explicitly so the Dataset is attached as a child
                    # node rather than handed to DataTree.__setitem__ bare.
                    target_tree[f"var_{var_name}"] = xr.DataTree(dataset=var_dss[0])
                    success = True
            except Exception as var_e:
                print(f"Error creating dataset for variable {var_name}: {var_e}")
    except Exception as nested_e:
        print(f"Failed to process variables: {nested_e}")

    return success

Try to separate data by variable names and create datasets.

Parameters
  • target_tree (xarray.DataTree): The tree node to add variable datasets to
  • df (pandas.DataFrame): DataFrame of messages
  • filename (str): Path to the GRIB2 file
Returns
  • bool: True if at least one variable was successfully processed
def _dataset_from_vertical_levels(var_name, var_df, filename, verbose=False):
    """
    Build one variable's Dataset by decoding each vertical level separately.

    The per-level DataArrays are concatenated along ``valueOfFirstFixedSurface``.
    Coordinates are assigned from the metadata of the first level that
    actually produced a DataArray.

    Returns
    -------
    xarray.Dataset or None
        None if no level could be decoded or the levels could not be combined.
    """
    level_das = []
    # (frames, cube, non_geo_dims, extra_geo, coord_attrs) of the first level
    # that was appended. Captured explicitly so that a later level which
    # fails part-way, or yields a frame count other than one, cannot leave
    # mismatched metadata behind for assign_xr_meta.
    level_meta = None

    for level, level_df in var_df.groupby('valueOfFirstFixedSurface'):
        if verbose:
            print(f"    Processing level {level} with {len(level_df)} messages")
        try:
            # Parse the index and get dimensions for this level
            file_index, non_geo_dims, attrs, coord_attrs = parse_grib_index(level_df, {})
            # The level dimension is built here by concatenation, not by the cube.
            non_geo_dims = [d for d in non_geo_dims if d.__name__ != "ValueOfFirstFixedSurfaceDim"]

            frames, cube, extra_geo = make_variables(
                file_index, filename, non_geo_dims, allow_uneven_dims=True)

            # Only levels that decode to exactly one frame are used.
            if frames is None or len(frames) != 1:
                if verbose:
                    print(f"    Skipping level {level} for {var_name}: expected exactly one frame")
                continue

            level_da = build_da_without_coords(frames[0], cube, filename, attrs)
            # Tag the DataArray with its level so concat builds the coordinate
            level_da = level_da.assign_coords(valueOfFirstFixedSurface=level)
        except Exception as e:
            if verbose:
                print(f"    Error processing level {level} for {var_name}: {e}")
            continue

        level_das.append(level_da)
        if level_meta is None:
            level_meta = (frames, cube, non_geo_dims, extra_geo, coord_attrs)

    if not level_das:
        return None

    if verbose:
        print(f"    Combining {len(level_das)} levels for {var_name}")
    try:
        combined_da = xr.concat(level_das, dim='valueOfFirstFixedSurface')
        var_ds = xr.Dataset({var_name: combined_da})
        frames, cube, non_geo_dims, extra_geo, coord_attrs = level_meta
        # TODO: confirm assign_xr_meta always adds refDate/leadTime/validDate
        # coords (previously done inline here).
        var_ds = assign_xr_meta(var_ds, frames, cube, non_geo_dims, extra_geo, coord_attrs)
    except Exception as e:
        if verbose:
            print(f"    Error combining levels for {var_name}: {e}")
        return None

    if verbose:
        print(f"    Created dataset for {var_name} with levels")
    return var_ds


def _dataset_from_single_level(var_name, var_df, filename, verbose=False):
    """
    Build one variable's Dataset when it has at most one vertical level.

    If decoding yields several frames, only the first one is used.

    Returns
    -------
    xarray.Dataset or None
        None if decoding produced no frames or raised.
    """
    try:
        # Parse the index and get dimensions
        file_index, non_geo_dims, attrs, coord_attrs = parse_grib_index(var_df, {})
        frames, cube, extra_geo = make_variables(file_index, filename, non_geo_dims, allow_uneven_dims=True)

        if frames is None or len(frames) == 0:
            return None
        if len(frames) > 1 and verbose:
            print(f"  Variable {var_name} has multiple frames, possibly different parameters")

        # Simplified approach: only the first frame becomes a data variable.
        var_ds = xr.Dataset()
        da = build_da_without_coords(frames[0], cube, filename, attrs)
        var_ds[da.name] = da

        # TODO: confirm assign_xr_meta always adds refDate/leadTime/validDate
        # coords (previously done inline here).
        var_ds = assign_xr_meta(var_ds, frames, cube, non_geo_dims, extra_geo, coord_attrs)
    except Exception as e:
        if verbose:
            print(f"  Error processing variable {var_name}: {e}")
        return None

    if verbose:
        if len(frames) == 1:
            print(f"  Created dataset for {var_name}")
        else:
            print(f"  Created dataset with first frame for {var_name}")
    return var_ds


def create_datasets_from_df(
    df,
    filename,
    verbose=False
) -> typing.Optional[typing.List[xr.Dataset]]:
    """
    Create a list of xarray Datasets from a DataFrame of messages.

    Messages are grouped by ``shortName`` and each variable becomes its own
    Dataset. A variable spanning more than one ``valueOfFirstFixedSurface``
    value is decoded level by level and concatenated along that dimension.
    The per-variable Datasets are then merged into one if possible.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame of GRIB messages
    filename : str
        Path to the GRIB2 file
    verbose : bool, optional
        If True, prints detailed debugging information

    Returns
    -------
    dss
        ``[merged_dataset]`` when all variable Datasets merge cleanly, the
        list of per-variable Datasets when merging fails, or None if no
        Dataset could be created (or an unexpected error occurred).
    """
    try:
        if verbose:
            print("\n==== VERBOSE DEBUG INFO ====")
            print(f"Creating dataset from DataFrame with {len(df)} messages")
            print(f"DataFrame columns: {df.columns.tolist()}")

            if 'shortName' in df.columns:
                print(f"Variables in group: {df['shortName'].unique().tolist()}")

            if 'valueOfFirstFixedSurface' in df.columns:
                print(f"Vertical levels: {df['valueOfFirstFixedSurface'].unique().tolist()}")

        # One Dataset per variable, keyed by shortName
        datasets = {}
        for var_name, var_df in df.groupby('shortName'):
            if verbose:
                print(
                    f"\n  Processing variable: {var_name} with {len(var_df)} messages, with pdtn(s) = {var_df['productDefinitionTemplateNumber'].unique()}")

            has_levels = (
                'valueOfFirstFixedSurface' in var_df.columns
                and len(var_df['valueOfFirstFixedSurface'].unique()) > 1
            )
            if has_levels:
                if verbose:
                    print(f"  Variable {var_name} has multiple vertical levels")
                var_ds = _dataset_from_vertical_levels(var_name, var_df, filename, verbose)
            else:
                if verbose:
                    print(f"  Variable {var_name} is a single level or has no vertical dimension")
                var_ds = _dataset_from_single_level(var_name, var_df, filename, verbose)

            if var_ds is not None:
                datasets[var_name] = var_ds

        if not datasets:
            if verbose:
                print("No datasets were created for any variables")
                print("==== END VERBOSE DEBUG INFO ====\n")
            return None

        if verbose:
            print(f"\nMerging {len(datasets)} datasets...")
        ds_list = list(datasets.values())

        try:
            combined_ds = xr.merge(ds_list)
        except Exception as merge_error:
            # Variables that cannot share coordinates are returned separately.
            if verbose:
                print(f"Error merging all datasets: {merge_error}")
                print("==== END VERBOSE DEBUG INFO ====\n")
            return ds_list

        if verbose:
            print("Successfully merged all datasets into one.")
            print(f"Final dataset has variables: {list(combined_ds.data_vars)}")
            print("==== END VERBOSE DEBUG INFO ====\n")
        return [combined_ds]

    except Exception as e:
        # If there's an error, log it and return None
        if verbose:
            print(f"Error creating dataset: {e}")
            import traceback
            traceback.print_exc()
            print("==== END VERBOSE DEBUG INFO ====\n")
        return None

Create a list of xarray Datasets from a DataFrame of messages.

Parameters
  • df (pandas.DataFrame): DataFrame of GRIB messages
  • filename (str): Path to the GRIB2 file
  • verbose (bool, optional): If True, prints detailed debugging information
Returns
  • dss: A one-element list holding the merged Dataset, the list of per-variable Datasets if they cannot be merged, or None if creation failed
@xr.register_datatree_accessor("grib2io")
class Grib2ioDataTree:
    """
    DataTree accessor for GRIB2 files.

    This accessor provides methods for working with GRIB2 data organized
    in a hierarchical tree structure. It is registered under the name
    ``grib2io``. Each method walks the tree and delegates per-node work to
    the ``grib2io`` accessor of the node's Dataset.
    """

    def __init__(self, datatree_obj):
        # The DataTree this accessor wraps.
        self._obj = datatree_obj

    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
        """
        Write all datasets in the DataTree to a GRIB2 file.

        Nodes are visited depth-first: a node before its children, and
        children in their stored order. Only nodes with data variables are
        written.

        Parameters
        ----------
        filename : str
            Name of the GRIB2 file to write to.
        mode : {"x", "w", "a"}, optional
            Persistence mode for the first write, default is "x" (create,
            fail if exists). Every later write uses "a" (append).
        """
        current_mode = mode

        def process_tree(node):
            nonlocal current_mode
            if node.ds is not None and node.ds.data_vars:
                node.ds.grib2io.to_grib2(filename, mode=current_mode)
                # Later datasets are appended to the file written first.
                current_mode = "a"
            for child_node in node.children.values():
                process_tree(child_node)

        process_tree(self._obj)

    def griddef(self):
        """
        Get the grid definition from the first dataset in the tree that has one.

        Returns
        -------
        grib2io.Grib2GridDef
            Grid definition built from the first ``GRIB2IO_section3``
            attribute found (depth-first), or None if there is none.
        """
        def find_griddef(node):
            if node.ds is not None and node.ds.data_vars:
                for var_name in node.ds.data_vars:
                    if 'GRIB2IO_section3' in node.ds[var_name].attrs:
                        return Grib2GridDef.from_section3(node.ds[var_name].attrs['GRIB2IO_section3'])

            for child_node in node.children.values():
                griddef = find_griddef(child_node)
                if griddef is not None:
                    return griddef

            return None

        return find_griddef(self._obj)

    def interp(self, method, grid_def_out, method_options=None, num_threads=1):
        """
        Interpolate all datasets in the tree to a new grid.

        Parameters
        ----------
        method : str or int
            Interpolation method to use
        grid_def_out : grib2io.Grib2GridDef
            Target grid definition
        method_options : list, optional
            Options for interpolation method
        num_threads : int, optional
            Number of threads to use for interpolation

        Returns
        -------
        xarray.DataTree
            New DataTree with the same node layout and interpolated data
        """
        def build(node):
            new_node = xr.DataTree()
            if node.ds is not None and node.ds.data_vars:
                new_node.ds = node.ds.grib2io.interp(method, grid_def_out,
                                                     method_options=method_options,
                                                     num_threads=num_threads)
            for child_name, child_node in node.children.items():
                # DataTree.__setitem__ attaches a copy of a DataTree value, so
                # each child subtree is fully built *before* it is attached.
                new_node[child_name] = build(child_node)
            return new_node

        return build(self._obj)

    def subset(self, lats, lons):
        """
        Subset all datasets in the tree to a region.

        Parameters
        ----------
        lats : list or tuple
            Latitude bounds [min_lat, max_lat]
        lons : list or tuple
            Longitude bounds [min_lon, max_lon]

        Returns
        -------
        xarray.DataTree
            New DataTree with the same node layout and subset data
        """
        def build(node):
            new_node = xr.DataTree()
            if node.ds is not None and node.ds.data_vars:
                new_node.ds = node.ds.grib2io.subset(lats, lons)
            for child_name, child_node in node.children.items():
                # DataTree.__setitem__ attaches a copy of a DataTree value, so
                # each child subtree is fully built *before* it is attached.
                new_node[child_name] = build(child_node)
            return new_node

        return build(self._obj)

DataTree accessor for GRIB2 files.

This accessor provides methods for working with GRIB2 data organized in a hierarchical tree structure.

Grib2ioDataTree(datatree_obj)
def __init__(self, datatree_obj):
    """
    Bind the accessor to a DataTree.

    Parameters
    ----------
    datatree_obj : xarray.DataTree
        The tree this accessor wraps; the accessor's methods read it from
        ``self._obj``.
    """
    wrapped_tree = datatree_obj
    self._obj = wrapped_tree
def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
    """
    Write every data-bearing node of the DataTree to one GRIB2 file.

    Nodes are visited depth-first (a node before its children, children in
    their stored order). The first write uses ``mode``; each later write
    appends.

    Parameters
    ----------
    filename : str
        Name of the GRIB2 file to write to.
    mode : {"x", "w", "a"}, optional
        Persistence mode for the first write, default is "x" (create, fail
        if exists).
    """
    def nodes_with_data(tree_node):
        # Lazily yield matching nodes so writes interleave with the walk.
        if tree_node.ds is not None and tree_node.ds.data_vars:
            yield tree_node
        for sub_node in tree_node.children.values():
            yield from nodes_with_data(sub_node)

    write_mode = mode
    for data_node in nodes_with_data(self._obj):
        data_node.ds.grib2io.to_grib2(filename, mode=write_mode)
        # Everything after the first write is appended to the same file.
        write_mode = "a"

Write all datasets in the DataTree to a GRIB2 file.

Parameters
  • filename (str): Name of the GRIB2 file to write to.
  • mode ({"x", "w", "a"}, optional): Persistence mode, default is "x" (create, fail if exists)
def griddef(self):
    """
    Get the grid definition from the first dataset in the tree that has one.

    The tree is searched depth-first (a node before its children, children
    in stored order), checking each node's data variables in order.

    Returns
    -------
    grib2io.Grib2GridDef
        Grid definition built from the first ``GRIB2IO_section3`` attribute
        found, or None if no data variable in the tree carries one.
    """
    pending = [self._obj]
    while pending:
        current = pending.pop()
        dataset = current.ds
        if dataset is not None and dataset.data_vars:
            for name in dataset.data_vars:
                var_attrs = dataset[name].attrs
                if 'GRIB2IO_section3' in var_attrs:
                    return Grib2GridDef.from_section3(var_attrs['GRIB2IO_section3'])
        # Push children reversed so the first child is the next one popped.
        pending.extend(reversed(list(current.children.values())))
    return None

Get the grid definition from the first dataset in the tree that has one.

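Returns
  • grib2io.Grib2GridDef: Grid definition object, or None if no dataset in the tree carries a GRIB2IO_section3 attribute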
def interp(self, method, grid_def_out, method_options=None, num_threads=1):
    """
    Interpolate all datasets in the tree to a new grid.

    Parameters
    ----------
    method : str or int
        Interpolation method to use
    grid_def_out : grib2io.Grib2GridDef
        Target grid definition
    method_options : list, optional
        Options for interpolation method
    num_threads : int, optional
        Number of threads to use for interpolation

    Returns
    -------
    xarray.DataTree
        New DataTree with the same node layout and interpolated data
    """
    def build(node):
        new_node = xr.DataTree()
        if node.ds is not None and node.ds.data_vars:
            new_node.ds = node.ds.grib2io.interp(method, grid_def_out,
                                                 method_options=method_options,
                                                 num_threads=num_threads)
        for child_name, child_node in node.children.items():
            # DataTree.__setitem__ attaches a copy of a DataTree value, so the
            # child subtree must be fully built *before* it is attached;
            # filling it afterwards would not reach the new tree.
            new_node[child_name] = build(child_node)
        return new_node

    return build(self._obj)

Interpolate all datasets in the tree to a new grid.

Parameters
  • method (str or int): Interpolation method to use
  • grid_def_out (grib2io.Grib2GridDef): Target grid definition
  • method_options (list, optional): Options for interpolation method
  • num_threads (int, optional): Number of threads to use for interpolation
Returns
  • xarray.DataTree: New DataTree with interpolated data
def subset(self, lats, lons):
    """
    Subset all datasets in the tree to a region.

    Parameters
    ----------
    lats : list or tuple
        Latitude bounds [min_lat, max_lat]
    lons : list or tuple
        Longitude bounds [min_lon, max_lon]

    Returns
    -------
    xarray.DataTree
        New DataTree with the same node layout and subset data
    """
    def build(node):
        new_node = xr.DataTree()
        if node.ds is not None and node.ds.data_vars:
            new_node.ds = node.ds.grib2io.subset(lats, lons)
        for child_name, child_node in node.children.items():
            # DataTree.__setitem__ attaches a copy of a DataTree value, so the
            # child subtree must be fully built *before* it is attached;
            # filling it afterwards would not reach the new tree.
            new_node[child_name] = build(child_node)
        return new_node

    return build(self._obj)

Subset all datasets in the tree to a region.

Parameters
  • lats (list or tuple): Latitude bounds [min_lat, max_lat]
  • lons (list or tuple): Longitude bounds [min_lon, max_lon]
Returns
  • xarray.DataTree: New DataTree with subset data