Skip to content

Commit 9c976a7

Browse files
authored
Merge branch 'main' into revert-precomp-changes-to-check-ci
2 parents 1de772a + 7a8a9a8 commit 9c976a7

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

cpp/src/kmeans/kmeans_predict.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ void predict_impl(const raft::handle_t& handle,
2626
idx_t* labels,
2727
value_t& inertia)
2828
{
29-
auto X_view = raft::make_device_matrix_view(X, n_samples, n_features);
30-
std::optional<raft::device_vector_view<const value_t>> sw = std::nullopt;
29+
auto X_view = raft::make_device_matrix_view<const value_t, idx_t>(X, n_samples, n_features);
30+
std::optional<raft::device_vector_view<const value_t, idx_t>> sw = std::nullopt;
3131
if (sample_weight != nullptr)
3232
sw = std::make_optional(
3333
raft::make_device_vector_view<const value_t, idx_t>(sample_weight, n_samples));

cpp/src/kmeans/kmeans_transform.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ void transform_impl(const raft::handle_t& handle,
2222
idx_t n_features,
2323
value_t* X_new)
2424
{
25-
auto X_view = raft::make_device_matrix_view<const value_t, idx_t>(X, n_samples, n_features);
25+
auto X_view = raft::make_device_matrix_view<const value_t, int>(X, n_samples, n_features);
2626
auto centroids_view =
27-
raft::make_device_matrix_view<const value_t, idx_t>(centroids, params.n_clusters, n_features);
28-
auto rX_new = raft::make_device_matrix_view<value_t, idx_t>(X_new, n_samples, n_features);
27+
raft::make_device_matrix_view<const value_t, int>(centroids, params.n_clusters, n_features);
28+
auto rX_new = raft::make_device_matrix_view<value_t, int>(X_new, n_samples, n_features);
2929

3030
cuvs::cluster::kmeans::transform(handle, params.to_cuvs(), X_view, centroids_view, rX_new);
3131
}

cpp/src/knn/knn.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ void brute_force_knn(const raft::handle_t& handle,
154154
raft::make_device_matrix_view<const int64_t, int64_t>(out_I, n * input.size(), k),
155155
raft::make_device_matrix_view<float, int64_t>(res_D, n, k),
156156
raft::make_device_matrix_view<int64_t, int64_t>(res_I, n, k),
157-
raft::make_device_vector_view<int64_t>(trans.data(), trans.size()));
157+
raft::make_device_vector_view<int64_t, int64_t>(trans.data(), trans.size()));
158158
}
159159

160160
if (translations == nullptr) delete id_ranges;

cpp/src/knn/knn_opg_common.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,7 @@ void reduce(opg_knn_param<in_t, ind_t, dist_t, out_t>& params,
714714
work.res_I.data(), batch_size * work.idxRanks.size(), params.k),
715715
raft::make_device_matrix_view<dist_t, int64_t>(distances, batch_size, params.k),
716716
raft::make_device_matrix_view<ind_t, int64_t>(indices, batch_size, params.k),
717-
raft::make_device_vector_view<trans_t>(trans.data(), trans.size()));
717+
raft::make_device_vector_view<trans_t, int64_t>(trans.data(), trans.size()));
718718
handle.sync_stream(handle.get_stream());
719719
RAFT_CUDA_TRY(cudaPeekAtLastError());
720720

0 commit comments

Comments
 (0)