Skip to content

Use matrix input in searchsortedfirst #61

@mtagliazucchi

Description

@mtagliazucchi

It would be useful to make `AcceleratedKernels.searchsortedfirst` also work when `x` is a matrix (while `v` remains a vector).

The minimum code change required would be:

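  1. widen the input type signatures of `searchsortedfirst!` and `searchsortedfirst` (from `AbstractVector` to `AbstractVecOrMat`);
  2. replace the `length(ix) == length(x)` check in the `@argcheck` of `searchsortedfirst!` with `size(ix) == size(x)`.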

The complete example is:

using AcceleratedKernels: foreachindex
using KernelAbstractions: get_backend, Backend
using ArgCheck: @argcheck

import AcceleratedKernels as AK

"""
    _searchsortedfirst(v, x, lo, hi, comp)

Binary search over `v[lo:hi]`, which must be sorted consistently with the two-argument
predicate `comp`, for the first index `i` at which `comp(v[i], x)` is false.

Returns `hi + 1` if `comp(v[i], x)` holds for every index in the range. Returns `lo` for
an empty range (`hi == lo - 1`). Accesses to `v` are not bounds-checked (`@inbounds`).
"""
function _searchsortedfirst(v, x, lo::T, hi::T, comp) where T<:Integer
  first_idx = lo
  # Size of the still-undecided window [first_idx, first_idx + remaining).
  remaining = hi - lo + T(1)
  @inbounds while remaining != 0x0
    half = remaining >>> 0x1
    probe = first_idx + half
    if comp(v[probe], x)
      # v[probe] is ordered before x, so the answer lies strictly after `probe`.
      first_idx = probe + 0x1
      remaining -= half + 0x1
    else
      # v[probe] is not before x, so `probe` itself is still a candidate.
      remaining = half
    end
  end
  return first_idx
end

"""
    _searchsortedfirst(v, x, lo, hi, ord::Base.Order.Ordering)

Binary search over `v[lo:hi]`, assumed sorted according to `ord`, for the first index
`i` such that `Base.Order.lt(ord, v[i], x)` is false.

Returns `hi + 1` when every element in the range is ordered before `x`. Returns `lo`
when the range is empty. Accesses to `v` are not bounds-checked (`@inbounds`).
"""
function _searchsortedfirst(v, x, lo::T, hi::T, ord::Base.Order.Ordering) where T<:Integer
  left = lo
  width = hi - lo + T(1)  # number of candidates left, starting at `left`
  @inbounds while !iszero(width)
    half = width >>> 0x1
    mid = left + half
    if !Base.Order.lt(ord, v[mid], x)
      # `mid` is not before x: shrink to the lower half, keeping `mid` as a candidate.
      width = half
    else
      # `mid` sorts before x: discard it and everything to its left.
      left = mid + 0x1
      width -= half + 0x1
    end
  end
  return left
end


"""
    searchsortedfirst!(ix, v, x, backend=get_backend(x);
                       by=identity, lt=isless, rev=false, min_elems=1000, kwargs...)

Fill `ix` so that `ix[i]` is the first index of `v` whose element is not ordered before
`x[i]`. The ordering is built from `by`, `lt` and `rev` with `Base.Order.ord`, just as
`Base.searchsortedfirst` does. `v` is assumed to be sorted under that same ordering.

`x` and `ix` may have any number of dimensions (vectors, matrices, ...) but must have
the same size; `@argcheck` throws otherwise. `min_elems` and any remaining keyword
arguments are forwarded to `foreachindex`.

Returns `ix`.
"""
function searchsortedfirst!(
  ix::AbstractArray, # generalized from AbstractVector: any dimensionality
  v::AbstractVector,
  x::AbstractArray, # generalized from AbstractVector: any dimensionality
  backend::Backend=get_backend(x);

  by=identity, lt=isless, rev::Bool=false,

  # CPU settings with different default from `foreachindex`
  min_elems::Int=1000,
  kwargs...
)
  # `ix` is written at the indices of `x`, so both arrays must have the same shape
  # (a plain length check would accept e.g. a 2x3 `ix` for a 3x2 `x`).
  @argcheck size(ix) == size(x)

  # Construct comparator; parameters are named `a`/`b` to avoid shadowing the array `x`.
  ord = Base.Order.ord(lt, by, rev)
  comp = (a, b) -> Base.Order.lt(ord, a, b)

  foreachindex(
    x, backend;
    min_elems, kwargs...
  ) do i
    @inbounds ix[i] = _searchsortedfirst(v, x[i], firstindex(v), lastindex(v), comp)
  end

  # Follow the Julia convention for mutating functions and return the mutated output.
  return ix
end

"""
    searchsortedfirst(v, x, backend=get_backend(x); kwargs...)

Allocating counterpart of `searchsortedfirst!`. It allocates an `Int` array shaped like
`x`, fills it with `searchsortedfirst!` (for each element of `x`, the first index of the
sorted vector `v` at which that element could be inserted while keeping `v` sorted),
and returns it.

All keyword arguments (`by`, `lt`, `rev`, `min_elems`, ...) are forwarded unchanged.
"""
function searchsortedfirst(
  v::AbstractVector,
  x::AbstractVecOrMat, # INSTEAD OF AbstractVector
  backend::Backend=get_backend(x);
  kwargs...
)
  indices = similar(x, Int)
  searchsortedfirst!(indices, v, x, backend; kwargs...)
  return indices
end

# `x` vector input (case already implemented)

x_vec = collect(range(0.25, 0.75; length=300))
v = collect(range(0.0, 1.0; length=100))

ix1 = AK.searchsortedfirst(v, x_vec)
ix2 = Base.searchsortedfirst.(Ref(v), x_vec)
ix3 = searchsortedfirst(v, x_vec)

@show ix1 == ix2 == ix3

# `x` matrix input: a 5x300 matrix whose rows are all copies of `x_vec`

x_mat = repeat(permutedims(x_vec), 5)

# ix1_mat = AK.searchsortedfirst(v, x_mat)  # fails: the current AK signature restricts `x` to AbstractVector
ix2_mat = Base.searchsortedfirst.(Ref(v), x_mat)
ix3_mat = searchsortedfirst(v, x_mat)

@show ix2_mat == ix3_mat

Output:

ix1 == ix2 == ix3 = true
ix2_mat == ix3_mat = true

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions