Skip to content

Commit d11fa33

Browse files
authored
Add raycast filter option (#2279)
1 parent 323ff7a commit d11fa33

File tree

5 files changed

+220
-11
lines changed

5 files changed

+220
-11
lines changed

dart/collision/RaycastOption.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,23 @@ namespace dart {
3636
namespace collision {
3737

3838
//==============================================================================
39-
RaycastOption::RaycastOption(bool enableAllHits, bool sortByClosest)
40-
: mEnableAllHits(enableAllHits), mSortByClosest(sortByClosest)
39+
RaycastOption::RaycastOption(
40+
bool enableAllHits, bool sortByClosest, RaycastFilter filter)
41+
: mEnableAllHits(enableAllHits),
42+
mSortByClosest(sortByClosest),
43+
mFilter(std::move(filter))
4144
{
4245
// Do nothing
4346
}
4447

48+
//==============================================================================
49+
bool RaycastOption::passesFilter(const CollisionObject* object) const
50+
{
51+
if (!mFilter)
52+
return true;
53+
54+
return mFilter(object);
55+
}
56+
4557
} // namespace collision
4658
} // namespace dart

dart/collision/RaycastOption.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,35 @@
3535

3636
#include <dart/Export.hpp>
3737

38+
#include <functional>
3839
#include <memory>
3940

4041
#include <cstddef>
4142

4243
namespace dart {
4344
namespace collision {
4445

46+
class CollisionObject;
47+
4548
struct DART_API RaycastOption
4649
{
50+
using RaycastFilter = std::function<bool(const CollisionObject*)>;
51+
4752
/// Constructor
48-
RaycastOption(bool enableAllHits = false, bool sortByClosest = false);
53+
RaycastOption(
54+
bool enableAllHits = false,
55+
bool sortByClosest = false,
56+
RaycastFilter filter = nullptr);
57+
58+
/// Returns true when the filter is not set or allows the object.
59+
bool passesFilter(const CollisionObject* object) const;
4960

5061
bool mEnableAllHits;
5162

5263
bool mSortByClosest;
5364

54-
// TODO(JS): Add filter
65+
/// Optional filter to reject hits from specific collision objects.
66+
RaycastFilter mFilter;
5567
};
5668

5769
} // namespace collision

dart/collision/bullet/BulletCollisionDetector.cpp

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -377,16 +377,50 @@ bool BulletCollisionDetector::raycast(
377377
const auto btFrom = convertVector3(from);
378378
const auto btTo = convertVector3(to);
379379

380-
if (option.mEnableAllHits) {
380+
const bool needsAllHits
381+
= option.mEnableAllHits || static_cast<bool>(option.mFilter);
382+
383+
if (needsAllHits) {
384+
auto lessFraction = [](const RayHit& a, const RayHit& b) {
385+
return a.mFraction < b.mFraction;
386+
};
387+
381388
auto callback = btCollisionWorld::AllHitsRayResultCallback(btFrom, btTo);
382389
castedGroup->updateEngineData();
383390
collisionWorld->rayTest(btFrom, btTo, callback);
384391

385-
if (result == nullptr)
386-
return callback.hasHit();
392+
if (result == nullptr) {
393+
if (!callback.hasHit())
394+
return false;
395+
396+
if (!option.mFilter)
397+
return true;
398+
399+
for (int i = 0; i < callback.m_collisionObjects.size(); ++i) {
400+
const auto* collObj = static_cast<BulletCollisionObject*>(
401+
callback.m_collisionObjects[i]->getUserPointer());
402+
if (option.passesFilter(collObj))
403+
return true;
404+
}
405+
406+
return false;
407+
}
387408

388409
if (callback.hasHit()) {
389410
reportRayHits(callback, option, *result);
411+
412+
if (!option.mEnableAllHits && !result->mRayHits.empty()) {
413+
if (option.mSortByClosest) {
414+
result->mRayHits.resize(1);
415+
} else {
416+
const auto closest = std::min_element(
417+
result->mRayHits.begin(), result->mRayHits.end(), lessFraction);
418+
const RayHit closestHit = *closest;
419+
result->mRayHits.clear();
420+
result->mRayHits.emplace_back(closestHit);
421+
}
422+
}
423+
390424
return result->hasHit();
391425
} else {
392426
return false;
@@ -765,7 +799,7 @@ RayHit convertRayHit(
765799
//==============================================================================
766800
void reportRayHits(
767801
const btCollisionWorld::ClosestRayResultCallback callback,
768-
const RaycastOption& /*option*/,
802+
const RaycastOption& option,
769803
RaycastResult& result)
770804
{
771805
// This function shouldn't be called if callback has not ray hit.
@@ -779,7 +813,9 @@ void reportRayHits(
779813

780814
result.mRayHits.clear();
781815
result.mRayHits.reserve(1);
782-
result.mRayHits.emplace_back(rayHit);
816+
817+
if (option.passesFilter(rayHit.mCollisionObject))
818+
result.mRayHits.emplace_back(rayHit);
783819
}
784820

785821
//==============================================================================
@@ -807,7 +843,8 @@ void reportRayHits(
807843
callback.m_hitPointWorld[i],
808844
callback.m_hitNormalWorld[i],
809845
callback.m_hitFractions[i]);
810-
result.mRayHits.emplace_back(rayHit);
846+
if (option.passesFilter(rayHit.mCollisionObject))
847+
result.mRayHits.emplace_back(rayHit);
811848
}
812849

813850
if (option.mSortByClosest)

python/stubs/dartpy/collision.pyi

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,14 +458,20 @@ class RayHit:
458458
class RaycastOption:
459459
mEnableAllHits: bool
460460
mSortByClosest: bool
461+
mFilter: typing.Optional[typing.Callable[[CollisionObject], bool]]
461462
@typing.overload
462463
def __init__(self) -> None:
463464
...
464465
@typing.overload
465466
def __init__(self, enableAllHits: bool) -> None:
466467
...
467468
@typing.overload
468-
def __init__(self, enableAllHits: bool, sortByClosest: bool) -> None:
469+
def __init__(
470+
self,
471+
enableAllHits: bool,
472+
sortByClosest: bool,
473+
filter: typing.Optional[typing.Callable[[CollisionObject], bool]] = None,
474+
) -> None:
469475
...
470476
class RaycastResult:
471477
mRayHits: list[RayHit]

tests/unit/collision/test_Raycast.cpp

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ using namespace collision;
4545
using namespace dynamics;
4646
using namespace dart::test;
4747

48+
namespace {
49+
50+
class DummyCollisionObject : public collision::CollisionObject
51+
{
52+
public:
53+
DummyCollisionObject(
54+
collision::CollisionDetector* detector, const dynamics::ShapeFrame* frame)
55+
: CollisionObject(detector, frame)
56+
{
57+
}
58+
59+
void updateEngineData() override {}
60+
};
61+
62+
} // namespace
63+
4864
//==============================================================================
4965
TEST(Raycast, RayHitDefaultConstructor)
5066
{
@@ -62,6 +78,45 @@ TEST(Raycast, RaycastResultDefaultConstructor)
6278
EXPECT_TRUE(result.mRayHits.empty());
6379
}
6480

81+
//==============================================================================
82+
TEST(Raycast, RaycastOptionPassesFilterWhenUnset)
83+
{
84+
auto detector = DARTCollisionDetector::create();
85+
86+
auto frame = SimpleFrame::createShared(Frame::World());
87+
frame->setShape(std::make_shared<SphereShape>(0.1));
88+
DummyCollisionObject object(detector.get(), frame.get());
89+
90+
RaycastOption option;
91+
92+
EXPECT_FALSE(option.mEnableAllHits);
93+
EXPECT_FALSE(option.mSortByClosest);
94+
EXPECT_TRUE(option.passesFilter(&object));
95+
}
96+
97+
//==============================================================================
98+
TEST(Raycast, RaycastOptionHonorsPredicate)
99+
{
100+
auto detector = DARTCollisionDetector::create();
101+
102+
auto firstFrame = SimpleFrame::createShared(Frame::World());
103+
firstFrame->setShape(std::make_shared<SphereShape>(0.1));
104+
auto secondFrame = SimpleFrame::createShared(Frame::World());
105+
secondFrame->setShape(std::make_shared<SphereShape>(0.1));
106+
107+
DummyCollisionObject first(detector.get(), firstFrame.get());
108+
DummyCollisionObject second(detector.get(), secondFrame.get());
109+
110+
RaycastOption option(true, true, [&](const collision::CollisionObject* obj) {
111+
return obj == &first;
112+
});
113+
114+
EXPECT_TRUE(option.mEnableAllHits);
115+
EXPECT_TRUE(option.mSortByClosest);
116+
EXPECT_TRUE(option.passesFilter(&first));
117+
EXPECT_FALSE(option.passesFilter(&second));
118+
}
119+
65120
//==============================================================================
66121
void testBasicInterface(const std::shared_ptr<CollisionDetector>& cd)
67122
{
@@ -242,3 +297,90 @@ TEST(Raycast, testOptions)
242297
auto dart = DARTCollisionDetector::create();
243298
testOptions(dart);
244299
}
300+
301+
//==============================================================================
302+
void testFilters(const std::shared_ptr<CollisionDetector>& cd)
303+
{
304+
if (cd->getType() != "bullet") {
305+
DART_WARN(
306+
"Aborting test: distance check is not supported by {}.", cd->getType());
307+
return;
308+
}
309+
310+
auto simpleFrame1 = SimpleFrame::createShared(Frame::World());
311+
auto simpleFrame2 = SimpleFrame::createShared(Frame::World());
312+
313+
auto sphere = std::make_shared<SphereShape>(1.0);
314+
simpleFrame1->setShape(sphere);
315+
simpleFrame2->setShape(sphere);
316+
317+
auto group = cd->createCollisionGroup(simpleFrame1.get(), simpleFrame2.get());
318+
319+
collision::RaycastOption option;
320+
collision::RaycastResult result;
321+
322+
simpleFrame1->setTranslation(Eigen::Vector3d(-2, 0, 0));
323+
simpleFrame2->setTranslation(Eigen::Vector3d(2, 0, 0));
324+
325+
option.mEnableAllHits = false;
326+
option.mSortByClosest = false;
327+
option.mFilter = [&](const collision::CollisionObject* obj) {
328+
return obj->getShapeFrame() == simpleFrame2.get();
329+
};
330+
331+
result.clear();
332+
cd->raycast(
333+
group.get(),
334+
Eigen::Vector3d(-5, 0, 0),
335+
Eigen::Vector3d(5, 0, 0),
336+
option,
337+
&result);
338+
ASSERT_TRUE(result.hasHit());
339+
ASSERT_EQ(result.mRayHits.size(), 1u);
340+
auto rayHit = result.mRayHits[0];
341+
EXPECT_TRUE(equals(rayHit.mPoint, Eigen::Vector3d(1, 0, 0)));
342+
EXPECT_TRUE(equals(rayHit.mNormal, Eigen::Vector3d(-1, 0, 0)));
343+
EXPECT_NEAR(rayHit.mFraction, 0.6, 1e-5);
344+
345+
option.mEnableAllHits = true;
346+
option.mFilter = [&](const collision::CollisionObject* obj) {
347+
return obj->getShapeFrame() == simpleFrame1.get();
348+
};
349+
result.clear();
350+
cd->raycast(
351+
group.get(),
352+
Eigen::Vector3d(-5, 0, 0),
353+
Eigen::Vector3d(5, 0, 0),
354+
option,
355+
&result);
356+
ASSERT_TRUE(result.hasHit());
357+
ASSERT_EQ(result.mRayHits.size(), 1u);
358+
rayHit = result.mRayHits[0];
359+
EXPECT_TRUE(equals(rayHit.mPoint, Eigen::Vector3d(-3, 0, 0)));
360+
EXPECT_TRUE(equals(rayHit.mNormal, Eigen::Vector3d(-1, 0, 0)));
361+
EXPECT_NEAR(rayHit.mFraction, 0.2, 1e-5);
362+
363+
option.mFilter = [&](const collision::CollisionObject*) {
364+
return false;
365+
};
366+
result.clear();
367+
cd->raycast(
368+
group.get(),
369+
Eigen::Vector3d(-5, 0, 0),
370+
Eigen::Vector3d(5, 0, 0),
371+
option,
372+
&result);
373+
EXPECT_FALSE(result.hasHit());
374+
EXPECT_TRUE(result.mRayHits.empty());
375+
}
376+
377+
//==============================================================================
378+
TEST(Raycast, testFilters)
379+
{
380+
#if HAVE_BULLET
381+
auto bullet = BulletCollisionDetector::create();
382+
testFilters(bullet);
383+
#else
384+
GTEST_SKIP() << "Bullet collision detector not available.";
385+
#endif
386+
}

0 commit comments

Comments
 (0)