diff --git a/include/xsimd/arch/xsimd_neon.hpp b/include/xsimd/arch/xsimd_neon.hpp
index 734e920f5..455431a0d 100644
--- a/include/xsimd/arch/xsimd_neon.hpp
+++ b/include/xsimd/arch/xsimd_neon.hpp
@@ -14,6 +14,7 @@
 #include <algorithm>
 #include <array>
+#include <cassert>
 #include <complex>
 #include <tuple>
 #include <type_traits>
@@ -2371,6 +2372,15 @@ namespace xsimd
                     return bitwise_lshift(lhs, n, ::xsimd::detail::int_sequence<Is...>());
                 }
             }
+
+            template <class T, class A>
+            XSIMD_INLINE bool shifts_all_positive(batch<T, A> const& b) noexcept
+            {
+                std::array<T, batch<T, A>::size> tmp = {};
+                b.store_unaligned(tmp.begin());
+                return std::all_of(tmp.begin(), tmp.end(), [](T x)
+                                   { return x >= 0; });
+            }
         }

         template <class A, class T>
@@ -2382,9 +2392,11 @@ namespace xsimd
        XSIMD_INLINE batch<T, A> bitwise_lshift(batch<T, A> const& lhs, int n, requires_arch<neon>) noexcept
         {
             return bitwise_lshift(lhs, n, ::xsimd::detail::int_sequence<Is...>());
         }
 
         template <class A, class T, detail::enable_sized_unsigned_t<T, 1> = 0>
-        XSIMD_INLINE batch<T, A> bitwise_lshift(batch<T, A> const& lhs, batch<as_signed_integer_t<T>, A> const& rhs, requires_arch<neon>) noexcept
+        XSIMD_INLINE batch<T, A> bitwise_lshift(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
         {
-            return vshlq_u8(lhs, rhs);
+            // Blindly converting to signed since out of bounds shifts are UB anyways
+            assert(detail::shifts_all_positive(rhs));
+            return vshlq_u8(lhs, vreinterpretq_s8_u8(rhs));
         }
 
         template <class A, class T, detail::enable_sized_signed_t<T, 2> = 0>
@@ -2394,9 +2406,11 @@ namespace xsimd
         }
 
         template <class A, class T, detail::enable_sized_unsigned_t<T, 2> = 0>
-        XSIMD_INLINE batch<T, A> bitwise_lshift(batch<T, A> const& lhs, batch<as_signed_integer_t<T>, A> const& rhs, requires_arch<neon>) noexcept
+        XSIMD_INLINE batch<T, A> bitwise_lshift(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
         {
-            return vshlq_u16(lhs, rhs);
+            // Blindly converting to signed since out of bounds shifts are UB anyways
+            assert(detail::shifts_all_positive(rhs));
+            return vshlq_u16(lhs, vreinterpretq_s16_u16(rhs));
         }
 
         template <class A, class T, detail::enable_sized_signed_t<T, 4> = 0>
@@ -2406,9 +2420,11 @@ namespace xsimd
         }
 
         template <class A, class T, detail::enable_sized_unsigned_t<T, 4> = 0>
-        XSIMD_INLINE batch<T, A> bitwise_lshift(batch<T, A> const& lhs, batch<as_signed_integer_t<T>, A> const& rhs, requires_arch<neon>) noexcept
+        XSIMD_INLINE batch<T, A> bitwise_lshift(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
         {
-            return vshlq_u32(lhs, rhs);
+            // Blindly converting to signed since out of bounds shifts are UB anyways
+            assert(detail::shifts_all_positive(rhs));
+            return vshlq_u32(lhs,
+                              vreinterpretq_s32_u32(rhs));
         }
 
         template <class A, class T, detail::enable_sized_signed_t<T, 8> = 0>
@@ -2418,9 +2434,11 @@ namespace xsimd
         }
 
         template <class A, class T, detail::enable_sized_unsigned_t<T, 8> = 0>
-        XSIMD_INLINE batch<T, A> bitwise_lshift(batch<T, A> const& lhs, batch<as_signed_integer_t<T>, A> const& rhs, requires_arch<neon>) noexcept
+        XSIMD_INLINE batch<T, A> bitwise_lshift(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
         {
-            return vshlq_u64(lhs, rhs);
+            // Blindly converting to signed since out of bounds shifts are UB
+            assert(detail::shifts_all_positive(rhs));
+            return vshlq_u64(lhs, vreinterpretq_s64_u64(rhs));
         }
 
         template <class A, class T, detail::enable_sized_signed_t<T, 1> = 0>
@@ -2618,9 +2636,11 @@ namespace xsimd
         }
 
         template <class A, class T, detail::enable_sized_unsigned_t<T, 1> = 0>
-        XSIMD_INLINE batch<T, A> bitwise_rshift(batch<T, A> const& lhs, batch<as_signed_integer_t<T>, A> const& rhs, requires_arch<neon>) noexcept
+        XSIMD_INLINE batch<T, A> bitwise_rshift(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
         {
-            return vshlq_u8(lhs, vnegq_s8(rhs));
+            // Blindly converting to signed since out of bounds shifts are UB anyways
+            assert(detail::shifts_all_positive(rhs));
+            return vshlq_u8(lhs, vnegq_s8(vreinterpretq_s8_u8(rhs)));
         }
 
         template <class A, class T, detail::enable_sized_signed_t<T, 2> = 0>
@@ -2630,9 +2650,11 @@ namespace xsimd
         }
 
         template <class A, class T, detail::enable_sized_unsigned_t<T, 2> = 0>
-        XSIMD_INLINE batch<T, A> bitwise_rshift(batch<T, A> const& lhs, batch<as_signed_integer_t<T>, A> const& rhs, requires_arch<neon>) noexcept
+        XSIMD_INLINE batch<T, A> bitwise_rshift(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
         {
-            return vshlq_u16(lhs, vnegq_s16(rhs));
+            // Blindly converting to signed since out of bounds shifts are UB anyways
+            assert(detail::shifts_all_positive(rhs));
+            return vshlq_u16(lhs, vnegq_s16(vreinterpretq_s16_u16(rhs)));
         }
 
         template <class A, class T, detail::enable_sized_signed_t<T, 4> = 0>
@@ -2642,9 +2664,11 @@ namespace xsimd
         }
 
         template <class A, class T, detail::enable_sized_unsigned_t<T, 4> = 0>
-        XSIMD_INLINE batch<T, A> bitwise_rshift(batch<T, A> const& lhs, batch<as_signed_integer_t<T>, A> const& rhs, requires_arch<neon>) noexcept
+        XSIMD_INLINE batch<T, A> bitwise_rshift(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
         {
-            return vshlq_u32(lhs, vnegq_s32(rhs));
+            // Blindly converting to signed since out of bounds shifts are UB anyways
+            assert(detail::shifts_all_positive(rhs));
+            return vshlq_u32(lhs, vnegq_s32(vreinterpretq_s32_u32(rhs)));
         }
 
         template <class A, class T, detail::enable_sized_signed_t<T, 8> = 0>
@@ -2654,9
+2678,12 @@ namespace xsimd
         }
 
         template <class A, class T, detail::enable_sized_unsigned_t<T, 8> = 0>
-        XSIMD_INLINE batch<T, A> bitwise_rshift(batch<T, A> const& lhs, batch<as_signed_integer_t<T>, A> const& rhs, requires_arch<neon>) noexcept
+        XSIMD_INLINE batch<T, A> bitwise_rshift(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon> req) noexcept
         {
-            return vshlq_u64(lhs, neg(rhs, neon {}).data);
+            // Blindly converting to signed since out of bounds shifts are UB anyways
+            assert(detail::shifts_all_positive(rhs));
+            using S = std::make_signed_t<T>;
+            return vshlq_u64(lhs, neg(batch<S, A>(vreinterpretq_s64_u64(rhs)), req).data);
         }
 
         template <class A, class T, detail::enable_sized_signed_t<T, 1> = 0>
diff --git a/include/xsimd/arch/xsimd_neon64.hpp b/include/xsimd/arch/xsimd_neon64.hpp
index 9f3c4bce8..edf04f0e0 100644
--- a/include/xsimd/arch/xsimd_neon64.hpp
+++ b/include/xsimd/arch/xsimd_neon64.hpp
@@ -12,6 +12,7 @@
 #ifndef XSIMD_NEON64_HPP
 #define XSIMD_NEON64_HPP
 
+#include <cassert>
 #include <complex>
 #include <cstddef>
 #include <tuple>
@@ -1209,9 +1210,11 @@ namespace xsimd
         }
 
         template <class A, class T, detail::enable_sized_unsigned_t<T, 8> = 0>
-        XSIMD_INLINE batch<T, A> bitwise_rshift(batch<T, A> const& lhs, batch<as_signed_integer_t<T>, A> const& rhs, requires_arch<neon64>) noexcept
+        XSIMD_INLINE batch<T, A> bitwise_rshift(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon64>) noexcept
         {
-            return vshlq_u64(lhs, vnegq_s64(rhs));
+            // Blindly converting to signed since out of bounds shifts are UB anyways
+            assert(detail::shifts_all_positive(rhs));
+            return vshlq_u64(lhs, vnegq_s64(vreinterpretq_s64_u64(rhs)));
         }
 
         template <class A, class T, detail::enable_sized_signed_t<T, 8> = 0>