Skip to content

Commit

Permalink
Merge pull request #668 from pangeo-forge/gcorradini/add_back_context…
Browse files Browse the repository at this point in the history
…_manager

Reuse `FSSpecTarget` Auth Credentials
  • Loading branch information
ranchodeluxe authored Jan 25, 2024
2 parents af3b80f + 6af025e commit e7ffdb8
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 67 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
runner-version: [
"pangeo-forge-runner==0.9.1",
"pangeo-forge-runner==0.9.2",
"pangeo-forge-runner==0.9.3",
]
steps:
- uses: actions/checkout@v4
Expand Down
1 change: 0 additions & 1 deletion examples/feedstock/hrrr_kerchunk_concat_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def test_ds(store: zarr.storage.FSStore) -> zarr.storage.FSStore:
concat_dims=pattern.concat_dims,
identical_dims=identical_dims,
precombine_inputs=True,
remote_protocol=remote_protocol,
)
| "Test dataset" >> beam.Map(test_ds)
)
2 changes: 1 addition & 1 deletion pangeo_forge_recipes/combiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class CombineMultiZarrToZarr(beam.CombineFn):
along a dimension that does not exist in the individual inputs. In this latter
case, precombining adds the additional dimension to the input so that its
dimensionality will match that of the accumulator.
:param storage_options: Storage options dict to pass to the MultiZarrToZarr
:param target_options: Target options dict to pass to the MultiZarrToZarr
"""

Expand Down
21 changes: 19 additions & 2 deletions pangeo_forge_recipes/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import unicodedata
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, replace
from typing import Iterator, Optional, Union
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Iterator, Optional, Union
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

import fsspec
Expand Down Expand Up @@ -87,10 +87,13 @@ class FSSpecTarget(AbstractTarget):
:param fs: The filesystem object we are writing to.
:param root_path: The path under which the target data will be stored.
:param fsspec_kwargs: The fsspec kwargs that can be reused as
`target_options` and `remote_options` for fsspec class instantiation
"""

fs: fsspec.AbstractFileSystem
root_path: str = ""
fsspec_kwargs: Dict[Any, Any] = field(default_factory=dict)

def __truediv__(self, suffix: str) -> FSSpecTarget:
"""
Expand All @@ -106,6 +109,20 @@ def from_url(cls, url: str):
assert len(root_paths) == 1
return cls(fs, root_paths[0])

def get_fsspec_remote_protocol(self):
"""fsspec implementation-specific remote protocal"""
fsspec_protocol = self.fs.protocol
if isinstance(fsspec_protocol, str):
return fsspec_protocol
elif isinstance(fsspec_protocol, tuple):
return fsspec_protocol[0]
elif isinstance(fsspec_protocol, list):
return fsspec_protocol[0]
else:
raise ValueError(
f"could not resolve fsspec protocol '{fsspec_protocol}' from underlying filesystem"
)

def get_mapper(self) -> fsspec.mapping.FSMap:
"""Get a mutable mapping object suitable for storing Zarr data."""
return FSStore(self.root_path, fs=self.fs)
Expand Down
30 changes: 7 additions & 23 deletions pangeo_forge_recipes/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,6 @@ class CombineReferences(beam.PTransform):
precombine_inputs: bool = False

def expand(self, references: beam.PCollection) -> beam.PCollection:

return references | beam.CombineGlobally(
CombineMultiZarrToZarr(
concat_dims=self.concat_dims,
Expand All @@ -482,9 +481,6 @@ class WriteReference(beam.PTransform, ZarrWriterMixin):
will be appended to this prefix to create a full path.
:param output_file_name: Name to give the output references file
(``.json`` or ``.parquet`` suffix).
:param target_options: Storage options for opening target files
:param remote_options: Storage options for opening remote files
:param remote_protocol: If files are accessed over the network, provide the remote protocol
over which they are accessed. e.g.: "s3", "gcp", "https", etc.
:param mzz_kwargs: Additional kwargs to pass to ``kerchunk.combine.MultiZarrToZarr``.
"""
Expand All @@ -495,9 +491,6 @@ class WriteReference(beam.PTransform, ZarrWriterMixin):
default_factory=RequiredAtRuntimeDefault
)
output_file_name: str = "reference.json"
target_options: Optional[Dict] = field(default_factory=lambda: {"anon": True})
remote_options: Optional[Dict] = field(default_factory=lambda: {"anon": True})
remote_protocol: Optional[str] = None
mzz_kwargs: dict = field(default_factory=dict)

def expand(self, references: beam.PCollection) -> beam.PCollection:
Expand All @@ -506,9 +499,6 @@ def expand(self, references: beam.PCollection) -> beam.PCollection:
full_target=self.get_full_target(),
concat_dims=self.concat_dims,
output_file_name=self.output_file_name,
target_options=self.target_options,
remote_options=self.remote_options,
remote_protocol=self.remote_protocol,
mzz_kwargs=self.mzz_kwargs,
)

Expand All @@ -520,10 +510,6 @@ class WriteCombinedReference(beam.PTransform, ZarrWriterMixin):
:param store_name: Zarr store will be created with this name under ``target_root``.
:param concat_dims: Dimensions along which to concatenate inputs.
:param identical_dims: Dimensions shared among all inputs.
:param target_options: Storage options for opening target files
:param remote_options: Storage options for opening remote files
:param remote_protocol: If files are accessed over the network, provide the remote protocol
over which they are accessed. e.g.: "s3", "gcp", "https", etc.
:param mzz_kwargs: Additional kwargs to pass to ``kerchunk.combine.MultiZarrToZarr``.
:param precombine_inputs: If ``True``, precombine each input with itself, using
``kerchunk.combine.MultiZarrToZarr``, before adding it to the accumulator.
Expand All @@ -543,9 +529,6 @@ class WriteCombinedReference(beam.PTransform, ZarrWriterMixin):
store_name: str
concat_dims: List[str]
identical_dims: List[str]
target_options: Optional[Dict] = field(default_factory=lambda: {"anon": True})
remote_options: Optional[Dict] = field(default_factory=lambda: {"anon": True})
remote_protocol: Optional[str] = None
mzz_kwargs: dict = field(default_factory=dict)
precombine_inputs: bool = False
target_root: Union[str, FSSpecTarget, RequiredAtRuntimeDefault] = field(
Expand All @@ -554,14 +537,18 @@ class WriteCombinedReference(beam.PTransform, ZarrWriterMixin):
output_file_name: str = "reference.json"

def expand(self, references: beam.PCollection) -> beam.PCollection[zarr.storage.FSStore]:
# unpack fsspec options that will be used below for transforms without dep injection
storage_options = self.target_root.fsspec_kwargs # type: ignore[union-attr]
remote_protocol = self.target_root.get_fsspec_remote_protocol() # type: ignore[union-attr]

return (
references
| CombineReferences(
concat_dims=self.concat_dims,
identical_dims=self.identical_dims,
target_options=self.target_options,
remote_options=self.remote_options,
remote_protocol=self.remote_protocol,
target_options=storage_options,
remote_options=storage_options,
remote_protocol=remote_protocol,
mzz_kwargs=self.mzz_kwargs,
precombine_inputs=self.precombine_inputs,
)
Expand All @@ -570,9 +557,6 @@ def expand(self, references: beam.PCollection) -> beam.PCollection[zarr.storage.
concat_dims=self.concat_dims,
target_root=self.target_root,
output_file_name=self.output_file_name,
target_options=self.target_options,
remote_options=self.remote_options,
remote_protocol=self.remote_protocol,
)
)

Expand Down
28 changes: 20 additions & 8 deletions pangeo_forge_recipes/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,24 @@ def write_combined_reference(
full_target: FSSpecTarget,
concat_dims: List[str],
output_file_name: str,
target_options: Optional[Dict] = {"anon": True},
remote_options: Optional[Dict] = {"anon": True},
remote_protocol: Optional[str] = None,
refs_per_component: int = 1000,
mzz_kwargs: Optional[Dict] = None,
) -> zarr.storage.FSStore:
"""Write a kerchunk combined references object to file."""
file_ext = os.path.splitext(output_file_name)[-1]
outpath = full_target._full_path(output_file_name)

import ujson # type: ignore

# unpack fsspec options that will be used below for call sites without dep injection
storage_options = full_target.fsspec_kwargs # type: ignore[union-attr]
remote_protocol = full_target.get_fsspec_remote_protocol() # type: ignore[union-attr]

# If reference is a ReferenceFileSystem, write to json
if isinstance(reference, fsspec.FSMap) and isinstance(reference.fs, ReferenceFileSystem):
reference.fs.save_json(outpath, **remote_options)
# context manager reuses dep injected auth credentials without passing storage options
with full_target.fs.open(outpath, "wb") as f:
f.write(ujson.dumps(reference.fs.references).encode())

elif file_ext == ".parquet":
# Creates empty parquet store to be written to
Expand All @@ -146,8 +151,8 @@ def write_combined_reference(
MultiZarrToZarr(
[reference],
concat_dims=concat_dims,
target_options=target_options,
remote_options=remote_options,
target_options=storage_options,
remote_options=storage_options,
remote_protocol=remote_protocol,
out=out,
**mzz_kwargs,
Expand All @@ -160,8 +165,15 @@ def write_combined_reference(
raise NotImplementedError(f"{file_ext = } not supported.")
return ReferenceFileSystem(
outpath,
target_options=target_options,
remote_options=remote_options,
target_options=storage_options,
# NOTE: `target_protocol` is required here b/c
# fsspec classes are inconsistent about deriving
# protocols if they are not passed. In this case ReferenceFileSystem
# decides how to read a reference based on `target_protocol` before
# it is automagically derived unfortunately
# https://github.com/fsspec/filesystem_spec/blob/master/fsspec/implementations/reference.py#L650-L663
target_protocol=remote_protocol,
remote_options=storage_options,
remote_protocol=remote_protocol,
lazy=True,
).get_mapper()
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ test = [
"pytest-sugar",
"pytest-timeout",
"s3fs",
"gcsfs",
"scipy",
]

Expand Down Expand Up @@ -76,7 +77,7 @@ line-length = 100

[tool.isort]
known_first_party = "pangeo_forge_recipes"
known_third_party = ["aiohttp", "apache_beam", "cftime", "click", "dask", "fsspec", "kerchunk", "numpy", "pandas", "pytest", "pytest_lazyfixture", "s3fs", "xarray", "zarr"]
known_third_party = ["aiohttp", "apache_beam", "cftime", "click", "dask", "fsspec", "gcsfs", "kerchunk", "numpy", "packaging", "pandas", "pytest", "pytest_lazyfixture", "s3fs", "xarray", "zarr"]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
Expand Down
6 changes: 0 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,12 +488,6 @@ def tmp_target(tmpdir_factory):
return FSSpecTarget(fs, path)


@pytest.fixture()
def tmp_target_url(tmpdir_factory):
path = str(tmpdir_factory.mktemp("target.zarr"))
return path


@pytest.fixture()
def tmp_cache(tmpdir_factory):
path = str(tmpdir_factory.mktemp("cache"))
Expand Down
24 changes: 12 additions & 12 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_xarray_zarr(
daily_xarray_dataset,
netcdf_local_file_pattern,
pipeline,
tmp_target_url,
tmp_target,
target_chunks,
):
pattern = netcdf_local_file_pattern
Expand All @@ -46,14 +46,14 @@ def test_xarray_zarr(
| beam.Create(pattern.items())
| OpenWithXarray(file_type=pattern.file_type)
| StoreToZarr(
target_root=tmp_target_url,
target_root=tmp_target,
store_name="store",
target_chunks=target_chunks,
combine_dims=pattern.combine_dim_keys,
)
)

ds = xr.open_dataset(os.path.join(tmp_target_url, "store"), engine="zarr")
ds = xr.open_dataset(os.path.join(tmp_target.root_path, "store"), engine="zarr")
assert ds.time.encoding["chunks"] == (target_chunks["time"],)
xr.testing.assert_equal(ds.load(), daily_xarray_dataset)

Expand All @@ -62,7 +62,7 @@ def test_xarray_zarr_subpath(
daily_xarray_dataset,
netcdf_local_file_pattern_sequential,
pipeline,
tmp_target_url,
tmp_target,
):
pattern = netcdf_local_file_pattern_sequential
with pipeline as p:
Expand All @@ -71,13 +71,13 @@ def test_xarray_zarr_subpath(
| beam.Create(pattern.items())
| OpenWithXarray(file_type=pattern.file_type)
| StoreToZarr(
target_root=tmp_target_url,
target_root=tmp_target,
store_name="subpath",
combine_dims=pattern.combine_dim_keys,
)
)

ds = xr.open_dataset(os.path.join(tmp_target_url, "subpath"), engine="zarr")
ds = xr.open_dataset(os.path.join(tmp_target.root_path, "subpath"), engine="zarr")
xr.testing.assert_equal(ds.load(), daily_xarray_dataset)


Expand All @@ -86,7 +86,7 @@ def test_reference_netcdf(
daily_xarray_dataset,
netcdf_local_file_pattern_sequential,
pipeline,
tmp_target_url,
tmp_target,
output_file_name,
):
pattern = netcdf_local_file_pattern_sequential
Expand All @@ -98,13 +98,13 @@ def test_reference_netcdf(
| OpenWithKerchunk(file_type=pattern.file_type)
| WriteCombinedReference(
identical_dims=["lat", "lon"],
target_root=tmp_target_url,
target_root=tmp_target,
store_name=store_name,
concat_dims=["time"],
output_file_name=output_file_name,
)
)
full_path = os.path.join(tmp_target_url, store_name, output_file_name)
full_path = os.path.join(tmp_target.root_path, store_name, output_file_name)
file_ext = os.path.splitext(output_file_name)[-1]
if file_ext == ".json":
mapper = fsspec.get_mapper("reference://", fo=full_path)
Expand All @@ -130,7 +130,7 @@ def test_reference_netcdf(
)
def test_reference_grib(
pipeline,
tmp_target_url,
tmp_target,
):
# This test adapted from:
# https://github.com/fsspec/kerchunk/blob/33b00d60d02b0da3f05ccee70d6ebc42d8e09932/kerchunk/tests/test_grib.py#L14-L31
Expand All @@ -148,11 +148,11 @@ def test_reference_grib(
| WriteCombinedReference(
concat_dims=[pattern.concat_dims[0]],
identical_dims=["latitude", "longitude"],
target_root=tmp_target_url,
target_root=tmp_target,
store_name=store_name,
)
)
full_path = os.path.join(tmp_target_url, store_name, "reference.json")
full_path = os.path.join(tmp_target.root_path, store_name, "reference.json")
mapper = fsspec.get_mapper("reference://", fo=full_path)
ds = xr.open_dataset(mapper, engine="zarr", backend_kwargs={"consolidated": False})
assert ds.attrs["GRIB_centre"] == "cwao"
Expand Down
6 changes: 6 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import secrets
import subprocess
import time
from importlib.metadata import version
from pathlib import Path

import pytest
from packaging.version import parse as parse_version

# Run only when the `--run-integration` option is passed.
# See also `pytest_addoption` in conftest. Reference:
Expand Down Expand Up @@ -119,6 +121,10 @@ def test_integration(confpath_option: str, recipe_id: str, request):
if recipe_id in xfails:
pytest.xfail(xfails[recipe_id])

runner_version = parse_version(version("pangeo-forge-runner"))
if recipe_id == "hrrr-kerchunk-concat-step" and runner_version <= parse_version("0.9.2"):
pytest.xfail("pg-runner version <= 0.9.2 didn't pass storage options")

confpath = request.getfixturevalue(confpath_option)

bake_script = (EXAMPLES / "runner-commands" / "bake.sh").absolute().as_posix()
Expand Down
Loading

0 comments on commit e7ffdb8

Please sign in to comment.