Skip to content

Commit 35fc0a3

Browse files
jinsolpjcrist
andauthored
Update python/cuml/cuml/manifold/umap/umap.pyx
Co-authored-by: Jim Crist-Harif <[email protected]>
1 parent 256a7ce commit 35fc0a3

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

python/cuml/cuml/manifold/umap/umap.pyx

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -931,11 +931,13 @@ class UMAP(Base, InteropMixin, CMajorInputTagMixin, SparseInputTagMixin):
931931
cdef uintptr_t X_ptr = 0, X_indices_ptr = 0, X_indptr_ptr = 0
932932
cdef size_t X_nnz = 0
933933

934-
mem_type = MemoryType.device
935-
if knn_graph is not None or self.precomputed_knn is not None:
936-
"""Mirrors the input data memory type to avoid unnecessary copies,
937-
since the data itself is not needed when precomputed KNN results are provided"""
938-
mem_type = False
934+
# Don't coerce to device memory when using a precomputed KNN, so
935+
# that X may be dropped earlier if passed on host.
936+
mem_type = (
937+
MemoryType.device
938+
if knn_graph is None and self.precomputed_knn is None
939+
else False
940+
)
939941

940942
if X_is_sparse:
941943
X_m = SparseCumlArray(X, convert_to_dtype=cp.float32, convert_to_mem_type=mem_type)

0 commit comments

Comments
 (0)