diff --git a/httomo/data/dataset_store.py b/httomo/data/dataset_store.py
index 139789373..a7975d7cd 100644
--- a/httomo/data/dataset_store.py
+++ b/httomo/data/dataset_store.py
@@ -34,13 +34,14 @@
     DataSetSource,
     ReadableDataSetSink,
 )
+from httomo.utils import make_pinned_host_array
 from mpi4py import MPI
 from mpi4py.util import dtlib
 import numpy as np
 from numpy.typing import DTypeLike
 import weakref
 
-from httomo.utils import log_once, make_3d_shape_from_shape
+from httomo.utils import log_once, make_3d_shape_from_shape, gpu_enabled
 
 
 class DataSetStoreWriter(ReadableDataSetSink):
@@ -384,7 +385,7 @@ def _read_block_file(
         self, shape: List[int], dim: int, start_idx: List[int]
     ) -> np.ndarray:
         start_idx[dim] += self._global_index[dim] - self._padding[0]
-        block_data = np.empty(shape, dtype=self._data.dtype)
+        block_data = make_pinned_host_array(shape, dtype=self._data.dtype)
         before_cut = 0
         after_cut = 0
         # check before boundary
@@ -547,7 +548,13 @@ def _read_block_ram(
             slice(start_idx[1], start_idx[1] + shape[1]),
             slice(start_idx[2], start_idx[2] + shape[2]),
         ]
-        return self._data[read_slices[0], read_slices[1], read_slices[2]]
+        data_slice = self._data[read_slices[0], read_slices[1], read_slices[2]]
+        if gpu_enabled:
+            block_data = make_pinned_host_array(shape, self._data.dtype)
+            block_data[:] = data_slice
+            return block_data
+        else:
+            return data_slice
 
     def read_block(self, start: int, length: int) -> DataSetBlock:
         shape = list(self._global_shape)
diff --git a/httomo/loaders/standard_tomo_loader.py b/httomo/loaders/standard_tomo_loader.py
index 09b2b508b..8e6d2abfe 100644
--- a/httomo/loaders/standard_tomo_loader.py
+++ b/httomo/loaders/standard_tomo_loader.py
@@ -21,7 +21,7 @@
 from httomo.runner.dataset_store_interfaces import DataSetSource
 from httomo.runner.loader import LoaderInterface
 from httomo.types import generic_array
-from httomo.utils import log_once, make_3d_shape_from_shape
+from httomo.utils import log_once, make_3d_shape_from_shape, make_pinned_host_array
 from httomo_backends.methods_database.query import Pattern
 
 
@@ -172,7 +172,7 @@ def read_block(self, start: int, length: int) -> DataSetBlock:
         start_idx[self._slicing_dim] += start + self._chunk_index[self._slicing_dim]
         block_shape = list(self.global_shape)
         block_shape[self._slicing_dim] = length + self._padding[0] + self._padding[1]
-        block_data = np.empty(block_shape, dtype=self._data.dtype)
+        block_data = make_pinned_host_array(block_shape, dtype=self._data.dtype)
 
         # Bools that reflect if an extended read is needed on either the lower or upper
         # boundary of the block, in order to fill in the before/after padded areas
diff --git a/httomo/utils.py b/httomo/utils.py
index 22ec483d5..3c131c41b 100644
--- a/httomo/utils.py
+++ b/httomo/utils.py
@@ -389,3 +389,40 @@ def search_max_slices_iterative(
         slices_low = current_slices
 
     return slices_low
+
+
+def make_pinned_host_array(shape, dtype) -> np.ndarray:
+    """Allocate an uninitialised host array, pinned when a GPU is in use.
+
+    When ``gpu_enabled`` is set the backing buffer is obtained from
+    ``xp.cuda.alloc_pinned_memory``; otherwise (and for arrays with no
+    elements) a plain ``np.empty`` array is returned.
+
+    Parameters
+    ----------
+    shape
+        Shape of the array, as accepted by ``np.empty``.
+    dtype
+        Element type of the array, as accepted by ``np.dtype``.
+
+    Returns
+    -------
+    np.ndarray
+        Uninitialised array with the requested shape and dtype.
+    """
+    if not gpu_enabled:
+        return np.empty(shape, dtype=dtype)
+
+    # ``np.prod`` yields a NumPy scalar (a float for an empty shape), so
+    # accumulate in int64 and convert to get an exact Python int.
+    num_elements = int(np.prod(shape, dtype=np.int64))
+    if num_elements == 0:
+        # Nothing to pin; skip the zero-byte allocation.
+        return np.empty(shape, dtype=dtype)
+
+    pinned_buffer = xp.cuda.alloc_pinned_memory(
+        num_elements * np.dtype(dtype).itemsize
+    )
+    # Pass ``count`` explicitly rather than relying on the buffer's length.
+    flat = np.frombuffer(pinned_buffer, dtype=dtype, count=num_elements)
+    return flat.reshape(shape)