diff --git a/dart/collision/RaycastOption.cpp b/dart/collision/RaycastOption.cpp index 829b6dd38bb86..d83a539e9e33c 100644 --- a/dart/collision/RaycastOption.cpp +++ b/dart/collision/RaycastOption.cpp @@ -36,11 +36,23 @@ namespace dart { namespace collision { //============================================================================== -RaycastOption::RaycastOption(bool enableAllHits, bool sortByClosest) - : mEnableAllHits(enableAllHits), mSortByClosest(sortByClosest) +RaycastOption::RaycastOption( + bool enableAllHits, bool sortByClosest, RaycastFilter filter) + : mEnableAllHits(enableAllHits), + mSortByClosest(sortByClosest), + mFilter(std::move(filter)) { // Do nothing } +//============================================================================== +bool RaycastOption::passesFilter(const CollisionObject* object) const +{ + if (!mFilter) + return true; + + return mFilter(object); +} + } // namespace collision } // namespace dart diff --git a/dart/collision/RaycastOption.hpp b/dart/collision/RaycastOption.hpp index 5b0fb1facd5f8..2e17edd2e19e3 100644 --- a/dart/collision/RaycastOption.hpp +++ b/dart/collision/RaycastOption.hpp @@ -35,6 +35,7 @@ #include +#include #include #include @@ -42,16 +43,27 @@ namespace dart { namespace collision { +class CollisionObject; + struct DART_API RaycastOption { + using RaycastFilter = std::function; + /// Constructor - RaycastOption(bool enableAllHits = false, bool sortByClosest = false); + RaycastOption( + bool enableAllHits = false, + bool sortByClosest = false, + RaycastFilter filter = nullptr); + + /// Returns true when the filter is not set or allows the object. + bool passesFilter(const CollisionObject* object) const; bool mEnableAllHits; bool mSortByClosest; - // TODO(JS): Add filter + /// Optional filter to reject hits from specific collision objects. + RaycastFilter mFilter; }; } // namespace collision diff --git a/dart/collision/bullet/BulletCollisionDetector.cpp b/dart/collision/bullet/BulletCollisionDetector.cpp index a9a4c51458ca9..3d9cadf420d18 100644 --- a/dart/collision/bullet/BulletCollisionDetector.cpp +++ b/dart/collision/bullet/BulletCollisionDetector.cpp @@ -377,16 +377,50 @@ bool BulletCollisionDetector::raycast( const auto btFrom = convertVector3(from); const auto btTo = convertVector3(to); - if (option.mEnableAllHits) { + const bool needsAllHits + = option.mEnableAllHits || static_cast(option.mFilter); + + if (needsAllHits) { + auto lessFraction = [](const RayHit& a, const RayHit& b) { + return a.mFraction < b.mFraction; + }; + auto callback = btCollisionWorld::AllHitsRayResultCallback(btFrom, btTo); castedGroup->updateEngineData(); collisionWorld->rayTest(btFrom, btTo, callback); - if (result == nullptr) - return callback.hasHit(); + if (result == nullptr) { + if (!callback.hasHit()) + return false; + + if (!option.mFilter) + return true; + + for (int i = 0; i < callback.m_collisionObjects.size(); ++i) { + const auto* collObj = static_cast( + callback.m_collisionObjects[i]->getUserPointer()); + if (option.passesFilter(collObj)) + return true; + } + + return false; + } if (callback.hasHit()) { reportRayHits(callback, option, *result); + + if (!option.mEnableAllHits && !result->mRayHits.empty()) { + if (option.mSortByClosest) { + result->mRayHits.resize(1); + } else { + const auto closest = std::min_element( + result->mRayHits.begin(), result->mRayHits.end(), lessFraction); + const RayHit closestHit = *closest; + result->mRayHits.clear(); + result->mRayHits.emplace_back(closestHit); + } + } + return result->hasHit(); } else { return false; @@ -765,7 +799,7 @@ RayHit convertRayHit( //============================================================================== void reportRayHits( const btCollisionWorld::ClosestRayResultCallback callback, - const RaycastOption& /*option*/, + const RaycastOption& option, RaycastResult& result) { // This function shouldn't be called if callback has not ray hit. @@ -779,7 +813,9 @@ void reportRayHits( result.mRayHits.clear(); result.mRayHits.reserve(1); - result.mRayHits.emplace_back(rayHit); + + if (option.passesFilter(rayHit.mCollisionObject)) + result.mRayHits.emplace_back(rayHit); } //============================================================================== @@ -807,7 +843,8 @@ void reportRayHits( callback.m_hitPointWorld[i], callback.m_hitNormalWorld[i], callback.m_hitFractions[i]); - result.mRayHits.emplace_back(rayHit); + if (option.passesFilter(rayHit.mCollisionObject)) + result.mRayHits.emplace_back(rayHit); } if (option.mSortByClosest) diff --git a/python/stubs/dartpy/collision.pyi b/python/stubs/dartpy/collision.pyi index ea58d6b7da628..38d758197c096 100644 --- a/python/stubs/dartpy/collision.pyi +++ b/python/stubs/dartpy/collision.pyi @@ -458,6 +458,7 @@ class RayHit: class RaycastOption: mEnableAllHits: bool mSortByClosest: bool + mFilter: typing.Optional[typing.Callable[[CollisionObject], bool]] @typing.overload def __init__(self) -> None: ... @@ -465,7 +466,12 @@ class RaycastOption: def __init__(self, enableAllHits: bool) -> None: ... @typing.overload - def __init__(self, enableAllHits: bool, sortByClosest: bool) -> None: + def __init__( + self, + enableAllHits: bool, + sortByClosest: bool, + filter: typing.Optional[typing.Callable[[CollisionObject], bool]] = None, + ) -> None: ... class RaycastResult: mRayHits: list[RayHit] diff --git a/tests/unit/collision/test_Raycast.cpp b/tests/unit/collision/test_Raycast.cpp index eb262d5840bcb..34e411ce7dc64 100644 --- a/tests/unit/collision/test_Raycast.cpp +++ b/tests/unit/collision/test_Raycast.cpp @@ -45,6 +45,22 @@ using namespace collision; using namespace dynamics; using namespace dart::test; +namespace { + +class DummyCollisionObject : public collision::CollisionObject +{ +public: + DummyCollisionObject( + collision::CollisionDetector* detector, const dynamics::ShapeFrame* frame) + : CollisionObject(detector, frame) + { + } + + void updateEngineData() override {} +}; + +} // namespace + //============================================================================== TEST(Raycast, RayHitDefaultConstructor) { @@ -62,6 +78,45 @@ TEST(Raycast, RaycastResultDefaultConstructor) EXPECT_TRUE(result.mRayHits.empty()); } +//============================================================================== +TEST(Raycast, RaycastOptionPassesFilterWhenUnset) +{ + auto detector = DARTCollisionDetector::create(); + + auto frame = SimpleFrame::createShared(Frame::World()); + frame->setShape(std::make_shared(0.1)); + DummyCollisionObject object(detector.get(), frame.get()); + + RaycastOption option; + + EXPECT_FALSE(option.mEnableAllHits); + EXPECT_FALSE(option.mSortByClosest); + EXPECT_TRUE(option.passesFilter(&object)); +} + +//============================================================================== +TEST(Raycast, RaycastOptionHonorsPredicate) +{ + auto detector = DARTCollisionDetector::create(); + + auto firstFrame = SimpleFrame::createShared(Frame::World()); + firstFrame->setShape(std::make_shared(0.1)); + auto secondFrame = SimpleFrame::createShared(Frame::World()); + secondFrame->setShape(std::make_shared(0.1)); + + DummyCollisionObject first(detector.get(), firstFrame.get()); + DummyCollisionObject second(detector.get(), secondFrame.get()); + + RaycastOption option(true, true, [&](const collision::CollisionObject* obj) { + return obj == &first; + }); + + EXPECT_TRUE(option.mEnableAllHits); + EXPECT_TRUE(option.mSortByClosest); + EXPECT_TRUE(option.passesFilter(&first)); + EXPECT_FALSE(option.passesFilter(&second)); +} + //============================================================================== void testBasicInterface(const std::shared_ptr& cd) { @@ -242,3 +297,90 @@ TEST(Raycast, testOptions) auto dart = DARTCollisionDetector::create(); testOptions(dart); } + +//============================================================================== +void testFilters(const std::shared_ptr& cd) +{ + if (cd->getType() != "bullet") { + DART_WARN( + "Aborting test: distance check is not supported by {}.", cd->getType()); + return; + } + + auto simpleFrame1 = SimpleFrame::createShared(Frame::World()); + auto simpleFrame2 = SimpleFrame::createShared(Frame::World()); + + auto sphere = std::make_shared(1.0); + simpleFrame1->setShape(sphere); + simpleFrame2->setShape(sphere); + + auto group = cd->createCollisionGroup(simpleFrame1.get(), simpleFrame2.get()); + + collision::RaycastOption option; + collision::RaycastResult result; + + simpleFrame1->setTranslation(Eigen::Vector3d(-2, 0, 0)); + simpleFrame2->setTranslation(Eigen::Vector3d(2, 0, 0)); + + option.mEnableAllHits = false; + option.mSortByClosest = false; + option.mFilter = [&](const collision::CollisionObject* obj) { + return obj->getShapeFrame() == simpleFrame2.get(); + }; + + result.clear(); + cd->raycast( + group.get(), + Eigen::Vector3d(-5, 0, 0), + Eigen::Vector3d(5, 0, 0), + option, + &result); + ASSERT_TRUE(result.hasHit()); + ASSERT_EQ(result.mRayHits.size(), 1u); + auto rayHit = result.mRayHits[0]; + EXPECT_TRUE(equals(rayHit.mPoint, Eigen::Vector3d(1, 0, 0))); + EXPECT_TRUE(equals(rayHit.mNormal, Eigen::Vector3d(-1, 0, 0))); + EXPECT_NEAR(rayHit.mFraction, 0.6, 1e-5); + + option.mEnableAllHits = true; + option.mFilter = [&](const collision::CollisionObject* obj) { + return obj->getShapeFrame() == simpleFrame1.get(); + }; + result.clear(); + cd->raycast( + group.get(), + Eigen::Vector3d(-5, 0, 0), + Eigen::Vector3d(5, 0, 0), + option, + &result); + ASSERT_TRUE(result.hasHit()); + ASSERT_EQ(result.mRayHits.size(), 1u); + rayHit = result.mRayHits[0]; + EXPECT_TRUE(equals(rayHit.mPoint, Eigen::Vector3d(-3, 0, 0))); + EXPECT_TRUE(equals(rayHit.mNormal, Eigen::Vector3d(-1, 0, 0))); + EXPECT_NEAR(rayHit.mFraction, 0.2, 1e-5); + + option.mFilter = [&](const collision::CollisionObject*) { + return false; + }; + result.clear(); + cd->raycast( + group.get(), + Eigen::Vector3d(-5, 0, 0), + Eigen::Vector3d(5, 0, 0), + option, + &result); + EXPECT_FALSE(result.hasHit()); + EXPECT_TRUE(result.mRayHits.empty()); +} + +//============================================================================== +TEST(Raycast, testFilters) +{ +#if HAVE_BULLET + auto bullet = BulletCollisionDetector::create(); + testFilters(bullet); +#else + GTEST_SKIP() << "Bullet collision detector not available."; +#endif +}