Skip to content

Support Per-node DataTree chunking #10105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from xarray.backends.locks import _get_scheduler
from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
from xarray.core import indexing
from xarray.core.chunk import _get_chunk, _maybe_chunk
from xarray.core.chunk import _get_chunk, _maybe_chunk, _maybe_get_path_chunk
from xarray.core.combine import (
_infer_concat_order_from_positions,
_nested_combine,
Expand Down Expand Up @@ -450,7 +450,7 @@ def _datatree_from_backend_datatree(
node.dataset,
filename_or_obj,
engine,
chunks,
_maybe_get_path_chunk(node.path, chunks),
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
Expand Down
12 changes: 12 additions & 0 deletions xarray/core/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,15 @@ def _maybe_chunk(
return var
else:
return var


def _maybe_get_path_chunk(path: str, chunks: int | dict | Any) -> int | dict | Any:
"""Returns path-specific chunks from a chunks dictionary, if path is a key of chunks.
Otherwise, returns chunks as is"""
if isinstance(chunks, dict):
try:
return chunks[path]
except KeyError:
pass

return chunks
4 changes: 3 additions & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def copy(
T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim]
T_ChunksFreq: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDimFreq]
# We allow the tuple form of this (though arguably we could transition to named dims only)
T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim]
T_Chunks: TypeAlias = (
T_ChunkDim | Mapping[Any, T_ChunkDim] | Mapping[Any, Mapping[Any, T_ChunkDim]]
)
T_NormalizedChunks = tuple[tuple[int, ...], ...]

DataVars = Mapping[Any, Any]
Expand Down
61 changes: 61 additions & 0 deletions xarray/tests/test_backends_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,37 @@ def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None:

assert_chunks_equal(tree, original_tree, enforce_dask=True)

@requires_dask
def test_open_datatree_path_chunks(self, tmpdir, simple_datatree) -> None:
    """Write a chunked tree to netCDF4 and reopen it with chunks keyed by node path."""
    target = tmpdir / "test.nc"

    # Chunk sizes for each node, keyed by the node's path in the tree.
    chunks = {
        "/": {"x": 2, "y": 1},
        "/group1": {"x": 1, "y": 2},
        "/group2": {"x": 2, "y": 3},
    }
    # Un-chunked contents of each node, keyed the same way.
    node_data = {
        "/": xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}),
        "/group1": xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])}),
        "/group2": xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])}),
    }

    # Chunk every node with its own sizes before building the reference tree.
    original_tree = DataTree.from_dict(
        {node_path: ds.chunk(chunks[node_path]) for node_path, ds in node_data.items()}
    )
    original_tree.to_netcdf(target, engine="netcdf4")

    with open_datatree(target, engine="netcdf4", chunks=chunks) as reopened:
        xr.testing.assert_identical(reopened, original_tree)
        assert_chunks_equal(reopened, original_tree, enforce_dask=True)

def test_open_groups(self, unaligned_datatree_nc) -> None:
"""Test `open_groups` with a netCDF4 file with an unaligned group hierarchy."""
unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc)
Expand Down Expand Up @@ -549,6 +580,36 @@ def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None:
# from each node.
xr.testing.assert_identical(tree.compute(), original_tree)

@requires_dask
def test_open_datatree_path_chunks(self, tmpdir, simple_datatree) -> None:
    """Write a chunked tree to zarr and reopen it with chunks keyed by node path.

    Guarded by ``requires_dask`` because the test chunks its datasets and
    asserts dask-backed chunks on the reopened tree.
    """
    filepath = tmpdir / "test.zarr"

    # Distinct chunk sizes per node, so a mix-up between nodes is detectable.
    root_chunks = {"x": 2, "y": 1}
    set1_chunks = {"x": 1, "y": 2}
    set2_chunks = {"x": 2, "y": 3}

    root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
    set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
    set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])})
    original_tree = DataTree.from_dict(
        {
            "/": root_data.chunk(root_chunks),
            "/group1": set1_data.chunk(set1_chunks),
            "/group2": set2_data.chunk(set2_chunks),
        }
    )
    original_tree.to_zarr(filepath)

    # Chunks mapping keyed by node path rather than by dimension name.
    chunks = {
        "/": root_chunks,
        "/group1": set1_chunks,
        "/group2": set2_chunks,
    }

    with open_datatree(filepath, engine="zarr", chunks=chunks) as tree:
        xr.testing.assert_identical(tree, original_tree)
        assert_chunks_equal(tree, original_tree, enforce_dask=True)
        # Also compare after loading the lazily-opened data.
        xr.testing.assert_identical(tree.compute(), original_tree)

def test_open_groups(self, unaligned_datatree_zarr) -> None:
"""Test `open_groups` with a zarr store of an unaligned group hierarchy."""

Expand Down
Loading