Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions dart/collision/RaycastOption.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,23 @@ namespace dart {
namespace collision {

//==============================================================================
RaycastOption::RaycastOption(bool enableAllHits, bool sortByClosest)
: mEnableAllHits(enableAllHits), mSortByClosest(sortByClosest)
RaycastOption::RaycastOption(
bool enableAllHits, bool sortByClosest, RaycastFilter filter)
: mEnableAllHits(enableAllHits),
mSortByClosest(sortByClosest),
mFilter(std::move(filter))
{
// Do nothing
}

//==============================================================================
bool RaycastOption::passesFilter(const CollisionObject* object) const
{
if (!mFilter)
return true;

return mFilter(object);
}

} // namespace collision
} // namespace dart
16 changes: 14 additions & 2 deletions dart/collision/RaycastOption.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,35 @@

#include <dart/Export.hpp>

#include <functional>
#include <memory>

#include <cstddef>

namespace dart {
namespace collision {

class CollisionObject;

struct DART_API RaycastOption
{
using RaycastFilter = std::function<bool(const CollisionObject*)>;

/// Constructor
RaycastOption(bool enableAllHits = false, bool sortByClosest = false);
RaycastOption(
bool enableAllHits = false,
bool sortByClosest = false,
RaycastFilter filter = nullptr);

/// Returns true when the filter is not set or allows the object.
bool passesFilter(const CollisionObject* object) const;

bool mEnableAllHits;

bool mSortByClosest;

// TODO(JS): Add filter
/// Optional filter to reject hits from specific collision objects.
RaycastFilter mFilter;
};

} // namespace collision
Expand Down
49 changes: 43 additions & 6 deletions dart/collision/bullet/BulletCollisionDetector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,16 +377,50 @@ bool BulletCollisionDetector::raycast(
const auto btFrom = convertVector3(from);
const auto btTo = convertVector3(to);

if (option.mEnableAllHits) {
const bool needsAllHits
= option.mEnableAllHits || static_cast<bool>(option.mFilter);

if (needsAllHits) {
auto lessFraction = [](const RayHit& a, const RayHit& b) {
return a.mFraction < b.mFraction;
};

auto callback = btCollisionWorld::AllHitsRayResultCallback(btFrom, btTo);
castedGroup->updateEngineData();
collisionWorld->rayTest(btFrom, btTo, callback);

if (result == nullptr)
return callback.hasHit();
if (result == nullptr) {
if (!callback.hasHit())
return false;

if (!option.mFilter)
return true;

for (int i = 0; i < callback.m_collisionObjects.size(); ++i) {
const auto* collObj = static_cast<BulletCollisionObject*>(
callback.m_collisionObjects[i]->getUserPointer());
if (option.passesFilter(collObj))
return true;
}

return false;
}

if (callback.hasHit()) {
reportRayHits(callback, option, *result);

if (!option.mEnableAllHits && !result->mRayHits.empty()) {
if (option.mSortByClosest) {
result->mRayHits.resize(1);
} else {
const auto closest = std::min_element(
result->mRayHits.begin(), result->mRayHits.end(), lessFraction);
const RayHit closestHit = *closest;
result->mRayHits.clear();
result->mRayHits.emplace_back(closestHit);
}
}

return result->hasHit();
} else {
return false;
Expand Down Expand Up @@ -765,7 +799,7 @@ RayHit convertRayHit(
//==============================================================================
void reportRayHits(
const btCollisionWorld::ClosestRayResultCallback callback,
const RaycastOption& /*option*/,
const RaycastOption& option,
RaycastResult& result)
{
// This function shouldn't be called if callback has not ray hit.
Expand All @@ -779,7 +813,9 @@ void reportRayHits(

result.mRayHits.clear();
result.mRayHits.reserve(1);
result.mRayHits.emplace_back(rayHit);

if (option.passesFilter(rayHit.mCollisionObject))
result.mRayHits.emplace_back(rayHit);
}

//==============================================================================
Expand Down Expand Up @@ -807,7 +843,8 @@ void reportRayHits(
callback.m_hitPointWorld[i],
callback.m_hitNormalWorld[i],
callback.m_hitFractions[i]);
result.mRayHits.emplace_back(rayHit);
if (option.passesFilter(rayHit.mCollisionObject))
result.mRayHits.emplace_back(rayHit);
}

if (option.mSortByClosest)
Expand Down
8 changes: 7 additions & 1 deletion python/stubs/dartpy/collision.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -458,14 +458,20 @@ class RayHit:
class RaycastOption:
mEnableAllHits: bool
mSortByClosest: bool
mFilter: typing.Optional[typing.Callable[[CollisionObject], bool]]
@typing.overload
def __init__(self) -> None:
...
@typing.overload
def __init__(self, enableAllHits: bool) -> None:
...
@typing.overload
def __init__(self, enableAllHits: bool, sortByClosest: bool) -> None:
def __init__(
self,
enableAllHits: bool,
sortByClosest: bool,
filter: typing.Optional[typing.Callable[[CollisionObject], bool]] = None,
) -> None:
...
class RaycastResult:
mRayHits: list[RayHit]
Expand Down
142 changes: 142 additions & 0 deletions tests/unit/collision/test_Raycast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ using namespace collision;
using namespace dynamics;
using namespace dart::test;

namespace {

class DummyCollisionObject : public collision::CollisionObject
{
public:
DummyCollisionObject(
collision::CollisionDetector* detector, const dynamics::ShapeFrame* frame)
: CollisionObject(detector, frame)
{
}

void updateEngineData() override {}
};

} // namespace

//==============================================================================
TEST(Raycast, RayHitDefaultConstructor)
{
Expand All @@ -62,6 +78,45 @@ TEST(Raycast, RaycastResultDefaultConstructor)
EXPECT_TRUE(result.mRayHits.empty());
}

//==============================================================================
TEST(Raycast, RaycastOptionPassesFilterWhenUnset)
{
auto detector = DARTCollisionDetector::create();

auto frame = SimpleFrame::createShared(Frame::World());
frame->setShape(std::make_shared<SphereShape>(0.1));
DummyCollisionObject object(detector.get(), frame.get());

RaycastOption option;

EXPECT_FALSE(option.mEnableAllHits);
EXPECT_FALSE(option.mSortByClosest);
EXPECT_TRUE(option.passesFilter(&object));
}

//==============================================================================
TEST(Raycast, RaycastOptionHonorsPredicate)
{
auto detector = DARTCollisionDetector::create();

auto firstFrame = SimpleFrame::createShared(Frame::World());
firstFrame->setShape(std::make_shared<SphereShape>(0.1));
auto secondFrame = SimpleFrame::createShared(Frame::World());
secondFrame->setShape(std::make_shared<SphereShape>(0.1));

DummyCollisionObject first(detector.get(), firstFrame.get());
DummyCollisionObject second(detector.get(), secondFrame.get());

RaycastOption option(true, true, [&](const collision::CollisionObject* obj) {
return obj == &first;
});

EXPECT_TRUE(option.mEnableAllHits);
EXPECT_TRUE(option.mSortByClosest);
EXPECT_TRUE(option.passesFilter(&first));
EXPECT_FALSE(option.passesFilter(&second));
}

//==============================================================================
void testBasicInterface(const std::shared_ptr<CollisionDetector>& cd)
{
Expand Down Expand Up @@ -242,3 +297,90 @@ TEST(Raycast, testOptions)
auto dart = DARTCollisionDetector::create();
testOptions(dart);
}

//==============================================================================
void testFilters(const std::shared_ptr<CollisionDetector>& cd)
{
if (cd->getType() != "bullet") {
DART_WARN(
"Aborting test: distance check is not supported by {}.", cd->getType());
return;
}

auto simpleFrame1 = SimpleFrame::createShared(Frame::World());
auto simpleFrame2 = SimpleFrame::createShared(Frame::World());

auto sphere = std::make_shared<SphereShape>(1.0);
simpleFrame1->setShape(sphere);
simpleFrame2->setShape(sphere);

auto group = cd->createCollisionGroup(simpleFrame1.get(), simpleFrame2.get());

collision::RaycastOption option;
collision::RaycastResult result;

simpleFrame1->setTranslation(Eigen::Vector3d(-2, 0, 0));
simpleFrame2->setTranslation(Eigen::Vector3d(2, 0, 0));

option.mEnableAllHits = false;
option.mSortByClosest = false;
option.mFilter = [&](const collision::CollisionObject* obj) {
return obj->getShapeFrame() == simpleFrame2.get();
};

result.clear();
cd->raycast(
group.get(),
Eigen::Vector3d(-5, 0, 0),
Eigen::Vector3d(5, 0, 0),
option,
&result);
ASSERT_TRUE(result.hasHit());
ASSERT_EQ(result.mRayHits.size(), 1u);
auto rayHit = result.mRayHits[0];
EXPECT_TRUE(equals(rayHit.mPoint, Eigen::Vector3d(1, 0, 0)));
EXPECT_TRUE(equals(rayHit.mNormal, Eigen::Vector3d(-1, 0, 0)));
EXPECT_NEAR(rayHit.mFraction, 0.6, 1e-5);

option.mEnableAllHits = true;
option.mFilter = [&](const collision::CollisionObject* obj) {
return obj->getShapeFrame() == simpleFrame1.get();
};
result.clear();
cd->raycast(
group.get(),
Eigen::Vector3d(-5, 0, 0),
Eigen::Vector3d(5, 0, 0),
option,
&result);
ASSERT_TRUE(result.hasHit());
ASSERT_EQ(result.mRayHits.size(), 1u);
rayHit = result.mRayHits[0];
EXPECT_TRUE(equals(rayHit.mPoint, Eigen::Vector3d(-3, 0, 0)));
EXPECT_TRUE(equals(rayHit.mNormal, Eigen::Vector3d(-1, 0, 0)));
EXPECT_NEAR(rayHit.mFraction, 0.2, 1e-5);

option.mFilter = [&](const collision::CollisionObject*) {
return false;
};
result.clear();
cd->raycast(
group.get(),
Eigen::Vector3d(-5, 0, 0),
Eigen::Vector3d(5, 0, 0),
option,
&result);
EXPECT_FALSE(result.hasHit());
EXPECT_TRUE(result.mRayHits.empty());
}

//==============================================================================
TEST(Raycast, testFilters)
{
#if HAVE_BULLET
auto bullet = BulletCollisionDetector::create();
testFilters(bullet);
#else
GTEST_SKIP() << "Bullet collision detector not available.";
#endif
}
Loading