-
Notifications
You must be signed in to change notification settings - Fork 7
Open
Description
It may be useful to make `AcceleratedKernels.searchsortedfirst` also work with `x` as a matrix input (while `v` remains a vector input).
The minimum code change required would be:
- change the input type signature in `searchsortedfirst!` and `searchsortedfirst`;
- change the `@argcheck` in `searchsortedfirst!`.
The complete example is:
using AcceleratedKernels: foreachindex
using KernelAbstractions: get_backend, Backend
using ArgCheck: @argcheck
import AcceleratedKernels as AK
"""
    _searchsortedfirst(v, x, lo::T, hi::T, comp) where T<:Integer

Binary-search `v[lo:hi]` and return the first index `i` for which `comp(v[i], x)` is
`false`, or `hi + 1` if `comp(v[i], x)` is `true` for every index in the range.

The result is only meaningful when `comp(v[i], x)` is `true` for a (possibly empty)
prefix of `lo:hi` and `false` afterwards, e.g. `v` sorted consistently with `comp`.

# Arguments
- `v`: indexable collection to search.
- `x`: value passed as the second argument to `comp`.
- `lo`, `hi`: inclusive search bounds; an empty range (`hi == lo - 1`) returns `lo`.
- `comp`: two-argument predicate, called as `comp(v[i], x)`.

# Returns
An index of type `T`.

Element access is wrapped in `@inbounds`, so the caller must guarantee that `lo:hi`
are valid indices of `v`.
"""
function _searchsortedfirst(v, x, lo::T, hi::T, comp) where T<:Integer
    # All arithmetic uses `one(T)` / `zero(T)` so that `lo` and `len` keep type `T`.
    # Mixing in `UInt8` literals would promote e.g. `Int8` indices to `UInt8`.
    len = hi - lo + one(T)              # number of candidate positions left
    @inbounds while len != zero(T)
        half_len = len >>> 0x1
        m = lo + half_len
        if comp(v[m], x)
            # v[m] sorts before x: the answer lies strictly to the right of m
            lo = m + one(T)
            len -= half_len + one(T)
        else
            # v[m] does not sort before x: keep searching the left half
            len = half_len
        end
    end
    return lo
end
"""
    _searchsortedfirst(v, x, lo::T, hi::T, ord::Base.Order.Ordering) where T<:Integer

Binary-search `v[lo:hi]` and return the first index `i` for which
`Base.Order.lt(ord, v[i], x)` is `false`, or `hi + 1` if it is `true` for every index
in the range.

The result is only meaningful when `v[lo:hi]` is sorted consistently with `ord`.

# Arguments
- `v`: indexable collection to search.
- `x`: value to locate.
- `lo`, `hi`: inclusive search bounds; an empty range (`hi == lo - 1`) returns `lo`.
- `ord`: ordering used for the comparison.

# Returns
An index of type `T`.

Element access is wrapped in `@inbounds`, so the caller must guarantee that `lo:hi`
are valid indices of `v`.
"""
function _searchsortedfirst(v, x, lo::T, hi::T, ord::Base.Order.Ordering) where T<:Integer
    # All arithmetic uses `one(T)` / `zero(T)` so that `lo` and `len` keep type `T`.
    # Mixing in `UInt8` literals would promote e.g. `Int8` indices to `UInt8`.
    len = hi - lo + one(T)              # number of candidate positions left
    @inbounds while len != zero(T)
        half_len = len >>> 0x1
        m = lo + half_len
        if Base.Order.lt(ord, v[m], x)
            # v[m] sorts before x: the answer lies strictly to the right of m
            lo = m + one(T)
            len -= half_len + one(T)
        else
            # v[m] does not sort before x: keep searching the left half
            len = half_len
        end
    end
    return lo
end
"""
    searchsortedfirst!(ix, v, x, backend=get_backend(x);
                       by=identity, lt=isless, rev=false, min_elems=1000, kwargs...)

For every index `i` of `x`, write into `ix[i]` the result of `_searchsortedfirst` for
`x[i]` over the full index range of `v`, under the ordering built from `lt`, `by` and
`rev`.

`ix` and `x` may each be a vector or a matrix (the upstream method only accepted
vectors); they must have the same `size`. `v` must be a vector.

`min_elems` and any extra `kwargs` are forwarded to `foreachindex`.
"""
function searchsortedfirst!(
    ix::AbstractVecOrMat,       # widened from AbstractVector
    v::AbstractVector,
    x::AbstractVecOrMat,        # widened from AbstractVector
    backend::Backend=get_backend(x);
    by=identity, lt=isless, rev::Bool=false,
    # CPU settings with different default from `foreachindex`
    min_elems::Int=1000,
    kwargs...
)
    # Shapes must agree; the upstream check compared lengths only
    @argcheck size(ix) == size(x)

    # Fold `lt`, `by` and `rev` into one two-argument "sorts before" predicate
    ordering = Base.Order.ord(lt, by, rev)
    precedes = (a, b) -> Base.Order.lt(ordering, a, b)

    foreachindex(x, backend; min_elems=min_elems, kwargs...) do idx
        @inbounds ix[idx] = _searchsortedfirst(
            v, x[idx], firstindex(v), lastindex(v), precedes
        )
    end
end
"""
    searchsortedfirst(v, x, backend=get_backend(x); kwargs...)

Allocating counterpart of `searchsortedfirst!`: create an `Int` array shaped like `x`
with `similar`, fill it via `searchsortedfirst!(result, v, x, backend; kwargs...)`, and
return it.

`x` may be a vector or a matrix (the upstream method only accepted vectors); `v` must
be a vector.
"""
function searchsortedfirst(
    v::AbstractVector,
    x::AbstractVecOrMat,        # widened from AbstractVector
    backend::Backend=get_backend(x);
    kwargs...
)
    result = similar(x, Int)
    searchsortedfirst!(result, v, x, backend; kwargs...)
    return result
end
# Case 1: `x` is a vector (already supported by AcceleratedKernels).
# Compare the AK result, an elementwise Base reference and the method defined above.
v = collect(range(0.0, 1.0; length=100))
x_vec = collect(range(0.25, 0.75; length=300))
ix1 = AK.searchsortedfirst(v, x_vec)
ix2 = map(xi -> Base.searchsortedfirst(v, xi), x_vec)
ix3 = searchsortedfirst(v, x_vec)
@show ix1 == ix2 == ix3

# Case 2: `x` is a matrix whose 5 rows are each a copy of `x_vec`.
x_mat = repeat(permutedims(x_vec), 5, 1)
# `AK.searchsortedfirst(v, x_mat)` is skipped here: it does not work with a matrix `x`.
ix2_mat = map(xi -> Base.searchsortedfirst(v, xi), x_mat)
ix3_mat = searchsortedfirst(v, x_mat)
@show ix2_mat == ix3_mat
Output:
ix1 == ix2 == ix3 = true
ix2_mat == ix3_mat = true
Metadata
Metadata
Assignees
Labels
No labels