Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Bug Fixes
- Fix a major performance regression in :py:meth:`Coordinates.to_index` (and
consequently :py:meth:`Dataset.to_dataframe`) caused by converting the cached
code ndarrays into Python lists (:issue:`11305`).
- Fix pickling of datasets opened from file-like objects with the scipy
backend when more than one such dataset has been opened (:issue:`11323`).


Documentation
Expand Down
56 changes: 29 additions & 27 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,33 +143,35 @@ def _open_scipy_netcdf(
) -> scipy.io.netcdf_file:
import scipy.io

# TODO: Remove this after upstreaming these fixes.
class flush_only_netcdf_file(scipy.io.netcdf_file):
# Subclass of scipy.io.netcdf_file whose close() flushes and rewinds the
# underlying file object instead of closing it, and whose __del__ does
# nothing.
#
# scipy.io.netcdf_file.close() incorrectly closes file objects that
# were passed in as constructor arguments:
# https://github.com/scipy/scipy/issues/13905

# Instead of closing such files, only call flush(), which is
# equivalent as long as the netcdf_file object is not mmapped.
# This suffices to keep BytesIO objects open long enough to read
# their contents from to_netcdf(), but underlying files still get
# closed when the netcdf_file is garbage collected (via __del__),
# and will need to be fixed upstream in scipy.
# NOTE(review): __del__ is overridden below as a no-op, so the remark
# above about files being closed on garbage collection looks out of
# date for this subclass -- confirm and update.
def close(self):
# Only act when an "fp" attribute exists and that file object is still
# open; otherwise close() is a silent no-op.
# NOTE(review): the hasattr guard presumably covers instances whose
# constructor failed before "fp" was assigned -- confirm.
if hasattr(self, "fp") and not self.fp.closed:
self.flush()
self.fp.seek(0) # allow file to be read again

def __del__(self):
# Remove the __del__ method, which in scipy is aliased to close().
# These files need to be closed explicitly by xarray.
pass

_PickleWorkaround.add_cls(flush_only_netcdf_file)

netcdf_file = (
_PickleWorkaround.flush_only_netcdf_file if flush_only else scipy.io.netcdf_file
)
if flush_only:
if not hasattr(_PickleWorkaround, "flush_only_netcdf_file"):
# TODO: Remove this after upstreaming these fixes.
class flush_only_netcdf_file(scipy.io.netcdf_file):
# Subclass of scipy.io.netcdf_file whose close() flushes and rewinds the
# underlying file object instead of closing it, and whose __del__ does
# nothing.
#
# scipy.io.netcdf_file.close() incorrectly closes file objects that
# were passed in as constructor arguments:
# https://github.com/scipy/scipy/issues/13905

# Instead of closing such files, only call flush(), which is
# equivalent as long as the netcdf_file object is not mmapped.
# This suffices to keep BytesIO objects open long enough to read
# their contents from to_netcdf(), but underlying files still get
# closed when the netcdf_file is garbage collected (via __del__),
# and will need to be fixed upstream in scipy.
# NOTE(review): __del__ is overridden below as a no-op, so the remark
# above about files being closed on garbage collection looks out of
# date for this subclass -- confirm and update.
def close(self):
# Only act when an "fp" attribute exists and that file object is still
# open; otherwise close() is a silent no-op.
# NOTE(review): the hasattr guard presumably covers instances whose
# constructor failed before "fp" was assigned -- confirm.
if hasattr(self, "fp") and not self.fp.closed:
self.flush()
self.fp.seek(0) # allow file to be read again

def __del__(self):
# Remove the __del__ method, which in scipy is aliased to close().
# These files need to be closed explicitly by xarray.
pass

_PickleWorkaround.add_cls(flush_only_netcdf_file)

netcdf_file = _PickleWorkaround.flush_only_netcdf_file
else:
netcdf_file = scipy.io.netcdf_file

# if the string ends with .gz, then gunzip and open as netcdf file
if isinstance(filename, str) and filename.endswith(".gz"):
Expand Down
18 changes: 18 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4735,6 +4735,24 @@ def test_pickle(self) -> None:
def test_pickle_dataarray(self) -> None:
# Delegates straight to the parent class's test_pickle_dataarray; no extra
# setup or assertions are added here.
super().test_pickle_dataarray()

def test_pickle_open_dataset_from_separate_file_objects(self) -> None:
    """Pickle the first of two datasets opened from separate BytesIO objects.

    A small dataset is serialised once with the scipy engine, then opened
    twice from two independent in-memory file objects holding the same
    bytes. The first opened dataset must still round-trip through pickle
    after the second one has been opened.
    """
    expected = Dataset({"foo": ("x", [1, 2, 3])})
    buffer = BytesIO()
    expected.to_netcdf(buffer, engine="scipy")
    raw = buffer.getvalue()

    # Datasets are recorded as they are opened so that the cleanup below
    # closes exactly the ones that were created, in opening order.
    opened = []
    try:
        for _ in range(2):
            opened.append(open_dataset(BytesIO(raw), engine="scipy"))
        first = opened[0]
        with pickle.loads(pickle.dumps(first)) as roundtripped:
            assert_identical(roundtripped, expected)
    finally:
        for dataset in opened:
            dataset.close()

@pytest.mark.parametrize("create_default_indexes", [True, False])
def test_create_default_indexes(self, tmp_path, create_default_indexes) -> None:
store_path = tmp_path / "tmp.nc"
Expand Down
Loading