Skip to content

Commit

Permalink
Support multi-raytrace graphs (#122)
Browse files Browse the repository at this point in the history
* Refactor findFieldToCompute and move to RaytraceNode

* Add test for spatial merge using multi-raytrace graph

* Extend format test with fields update

* Add comments suggested in review
  • Loading branch information
msz-rai committed Mar 29, 2023
1 parent b6c6965 commit 4c31872
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 50 deletions.
30 changes: 0 additions & 30 deletions src/graph/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include <graph/Graph.hpp>
#include <graph/NodesCore.hpp>

std::list<std::shared_ptr<Graph>> Graph::instances;

Expand Down Expand Up @@ -43,7 +42,6 @@ std::shared_ptr<Graph> Graph::create(std::shared_ptr<Node> node)

graph->nodes = Graph::findConnectedNodes(node);
graph->executionOrder = Graph::findExecutionOrder(graph->nodes);
graph->fieldsToCompute = Graph::findFieldsToCompute(graph->nodes);

for (auto&& currentNode : graph->nodes) {
if (currentNode->hasGraph()) {
Expand All @@ -64,12 +62,6 @@ void Graph::run()

RGL_DEBUG("Running graph with {} nodes", nodesInExecOrder.size());

// If Graph has RaytraceNode set fields to compute
if (!Node::filter<RaytraceNode>(nodesInExecOrder).empty()) {
RaytraceNode::Ptr rt = Node::getExactlyOne<RaytraceNode>(nodesInExecOrder);
rt->setFields(fieldsToCompute);
}

for (auto&& current : nodesInExecOrder) {
RGL_DEBUG("Validating node: {}", *current);
current->validate();
Expand Down Expand Up @@ -129,28 +121,6 @@ std::vector<std::shared_ptr<Node>> Graph::findExecutionOrder(std::set<std::share
return {reverseOrder.rbegin(), reverseOrder.rend()};
}

std::set<rgl_field_t> Graph::findFieldsToCompute(std::set<std::shared_ptr<Node>> nodes)
{
std::set<rgl_field_t> outFields;
for (auto&& node : nodes) {
if (auto pointNode = std::dynamic_pointer_cast<IPointsNode>(node)) {
for (auto&& field : pointNode->getRequiredFieldList()) {
if (!isDummy(field)) {
outFields.insert(field);
}
}
}
}

outFields.insert(XYZ_F32);

if (!Node::filter<CompactPointsNode>(nodes).empty()) {
outFields.insert(IS_HIT_I32);
}

return outFields;
}

Graph::~Graph()
{
stream.reset();
Expand Down
3 changes: 1 addition & 2 deletions src/graph/Graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <vector>

#include <graph/Node.hpp>
#include <graph/NodesCore.hpp>
#include <CudaStream.hpp>

struct Graph
Expand All @@ -37,13 +38,11 @@ struct Graph
Graph() : stream(std::make_shared<CudaStream>()) {}

static std::vector<std::shared_ptr<Node>> findExecutionOrder(std::set<std::shared_ptr<Node>> nodes);
static std::set<rgl_field_t> findFieldsToCompute(std::set<std::shared_ptr<Node>> nodes);

private:
std::shared_ptr<CudaStream> stream;
std::set<std::shared_ptr<Node>> nodes;
std::vector<std::shared_ptr<Node>> executionOrder;
std::set<rgl_field_t> fieldsToCompute;

static std::list<std::shared_ptr<Graph>> instances;
};
8 changes: 6 additions & 2 deletions src/graph/NodesCore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ struct CompactPointsNode : Node, IPointsNodeSingleInput
void validate() override;
void schedule(cudaStream_t stream) override;

// Node requirements
std::vector<rgl_field_t> getRequiredFieldList() const override { return {IS_HIT_I32}; }

// Point cloud description
bool isDense() const override { return true; }
size_t getWidth() const override;
Expand Down Expand Up @@ -124,8 +127,6 @@ struct RaytraceNode : Node, IPointsNode
VArray::ConstPtr getFieldData(rgl_field_t field, cudaStream_t stream) const override
{ return std::const_pointer_cast<const VArray>(fieldData.at(field)); }


void setFields(const std::set<rgl_field_t>& fields);
private:
float range;
std::shared_ptr<Scene> scene;
Expand All @@ -135,6 +136,9 @@ struct RaytraceNode : Node, IPointsNode

template<rgl_field_t>
auto getPtrTo();

std::set<rgl_field_t> findFieldsToCompute();
void setFields(const std::set<rgl_field_t>& fields);
};

struct TransformPointsNode : Node, IPointsNodeSingleInput
Expand Down
31 changes: 31 additions & 0 deletions src/graph/RaytraceNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

void RaytraceNode::validate()
{
// It should be viewed as a temporary solution. Will change in v14.
setFields(findFieldsToCompute());

raysNode = getValidInput<IRaysNode>();

if (fieldData.contains(RING_ID_U16) && !raysNode->getRingIds().has_value()) {
Expand Down Expand Up @@ -93,3 +96,31 @@ void RaytraceNode::setFields(const std::set<rgl_field_t>& fields)
fieldData.insert({field, VArray::create(field)});
}
}

std::set<rgl_field_t> RaytraceNode::findFieldsToCompute()
{
std::set<rgl_field_t> outFields;

// Add primary field
outFields.insert(XYZ_F32);

// dfsInputs - if false dfs for outputs
std::function<void(Node::Ptr, bool)> dfsRet = [&](const Node::Ptr & current, bool dfsInputs) {
auto dfsNodes = dfsInputs ? current->getInputs() : current->getOutputs();
for (auto&& node : dfsNodes) {
if (auto pointNode = std::dynamic_pointer_cast<IPointsNode>(node)) {
for (auto&& field : pointNode->getRequiredFieldList()) {
if (!isDummy(field)) {
outFields.insert(field);
}
}
dfsRet(node, dfsInputs);
}
}
};

dfsRet(shared_from_this(), true); // Search in inputs. Needed for SetRingIds only.
dfsRet(shared_from_this(), false); // Search in outputs

return outFields;
}
108 changes: 92 additions & 16 deletions test/src/graphTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ TEST_F(GraphCase, NodeRemoval)
#endif
}

TEST_F(GraphCase, SpatialMerge)
TEST_F(GraphCase, SpatialMergeFromTransforms)
{
auto mesh = makeCubeMesh();

Expand Down Expand Up @@ -144,6 +144,51 @@ TEST_F(GraphCase, SpatialMerge)
#endif
}

TEST_F(GraphCase, SpatialMergeFromRaytraces)
{
// Setup cube scene
auto mesh = makeCubeMesh();
auto entity = makeEntity(mesh);
rgl_mat3x4f entityPoseTf = Mat3x4f::identity().toRGL();
ASSERT_RGL_SUCCESS(rgl_entity_set_pose(entity, &entityPoseTf));

constexpr int LIDAR_FOV_Y = 40;
constexpr int LIDAR_ROTATION_STEP = LIDAR_FOV_Y / 2; // Make laser overlaps to validate merging

std::vector<rgl_mat3x4f> rays = makeLidar3dRays(180, LIDAR_FOV_Y, 0.18, 1);

// Lidars will be located in the cube center with different rotations covering all the space.
std::vector<rgl_mat3x4f> lidarTfs;
for (int i = 0; i < 360 / LIDAR_ROTATION_STEP; ++i) {
lidarTfs.emplace_back(Mat3x4f::TRS({0, 0, 0}, {0, LIDAR_ROTATION_STEP * i, 0}).toRGL());
}

rgl_node_t spatialMerge=nullptr;
std::vector<rgl_field_t> sMergeFields = { RGL_FIELD_XYZ_F32, RGL_FIELD_DISTANCE_F32 };
EXPECT_RGL_SUCCESS(rgl_node_points_spatial_merge(&spatialMerge, sMergeFields.data(), sMergeFields.size()));

for (auto& lidarTf : lidarTfs) {
rgl_node_t lidarRays = nullptr;
rgl_node_t lidarRaysTf = nullptr;
rgl_node_t raytrace = nullptr;

EXPECT_RGL_SUCCESS(rgl_node_rays_from_mat3x4f(&lidarRays, rays.data(), rays.size()));
EXPECT_RGL_SUCCESS(rgl_node_rays_transform(&lidarRaysTf, &lidarTf));
EXPECT_RGL_SUCCESS(rgl_node_raytrace(&raytrace, nullptr, 1000));

EXPECT_RGL_SUCCESS(rgl_graph_node_add_child(lidarRays, lidarRaysTf));
EXPECT_RGL_SUCCESS(rgl_graph_node_add_child(lidarRaysTf, raytrace));
EXPECT_RGL_SUCCESS(rgl_graph_node_add_child(raytrace, spatialMerge));
}

EXPECT_RGL_SUCCESS(rgl_graph_run(spatialMerge));
#ifdef RGL_BUILD_PCL_EXTENSION
EXPECT_RGL_SUCCESS(rgl_graph_write_pcd_file(spatialMerge, "cube_spatial_merge.pcd"));
#else
RGL_WARN("RGL compiled without PCL extension. Tests will not save PCD!");
#endif
}

TEST_F(GraphCase, TemporalMerge)
{
auto mesh = makeCubeMesh();
Expand Down Expand Up @@ -197,6 +242,9 @@ TEST_F(GraphCase, FormatNodeResults)

rgl_node_t useRays=nullptr, raytrace=nullptr, lidarPose=nullptr, format=nullptr;

// The cube located in 0,0,0 with width equals 1, rays shoot in perpendicular direction
constexpr float EXPECTED_HITPOINT_Z = 1.0f;
constexpr float EXPECTED_RAY_DISTANCE = 1.0f;
std::vector<rgl_mat3x4f> rays = {
Mat3x4f::TRS({0, 0, 0}).toRGL(),
Mat3x4f::TRS({0.1, 0, 0}).toRGL(),
Expand All @@ -208,11 +256,17 @@ TEST_F(GraphCase, FormatNodeResults)
std::vector<rgl_field_t> formatFields = {
XYZ_F32,
PADDING_32,
TIME_STAMP_F64
TIME_STAMP_F64
};
struct FormatStruct
{
Field<XYZ_F32>::type xyz;
Field<PADDING_32>::type padding;
Field<TIME_STAMP_F64>::type timestamp;
} formatStruct;

Time timestamp = Time::seconds(1.5);
EXPECT_RGL_SUCCESS(rgl_scene_set_time(nullptr, timestamp.asNanoseconds()));
Time timestamp = Time::seconds(1.5);
EXPECT_RGL_SUCCESS(rgl_scene_set_time(nullptr, timestamp.asNanoseconds()));

EXPECT_RGL_SUCCESS(rgl_node_rays_from_mat3x4f(&useRays, rays.data(), rays.size()));
EXPECT_RGL_SUCCESS(rgl_node_rays_transform(&lidarPose, &lidarPoseTf));
Expand All @@ -228,24 +282,46 @@ TEST_F(GraphCase, FormatNodeResults)
int32_t outCount, outSizeOf;
EXPECT_RGL_SUCCESS(rgl_graph_get_result_size(format, RGL_FIELD_DYNAMIC_FORMAT, &outCount, &outSizeOf));

struct FormatStruct
{
Field<XYZ_F32>::type xyz;
Field<PADDING_32>::type padding;
Field<TIME_STAMP_F64>::type timestamp;
} formatStruct;

EXPECT_EQ(outCount, rays.size());
EXPECT_EQ(outSizeOf, sizeof(formatStruct));

std::vector<FormatStruct> formatData{(size_t)outCount};
std::vector<FormatStruct> formatData(outCount);
EXPECT_RGL_SUCCESS(rgl_graph_get_result_data(format, RGL_FIELD_DYNAMIC_FORMAT, formatData.data()));

for (int i = 0; i < formatData.size(); ++i) {
EXPECT_NEAR(formatData[i].xyz[0], rays[i].value[0][3], 1e-6);
EXPECT_NEAR(formatData[i].xyz[1], rays[i].value[1][3], 1e-6);
EXPECT_NEAR(formatData[i].xyz[2], 1, 1e-6);
EXPECT_EQ(formatData[i].timestamp, timestamp.asSeconds());
EXPECT_NEAR(formatData[i].xyz[0], rays[i].value[0][3], EPSILON_F);
EXPECT_NEAR(formatData[i].xyz[1], rays[i].value[1][3], EPSILON_F);
EXPECT_NEAR(formatData[i].xyz[2], EXPECTED_HITPOINT_Z, EPSILON_F);
EXPECT_EQ(formatData[i].timestamp, timestamp.asSeconds());
}

// Test if fields update is propagated over graph properly
formatFields.push_back(DISTANCE_F32); // Add distance field
formatFields.push_back(PADDING_32); // Align to 8 bytes
struct FormatStructExtended : public FormatStruct
{
Field<DISTANCE_F32>::type distance;
Field<PADDING_32>::type padding2; // Align to 8 bytes
} formatStructEx;

EXPECT_RGL_SUCCESS(rgl_node_points_format(&format, formatFields.data(), formatFields.size()));
EXPECT_RGL_SUCCESS(rgl_graph_run(raytrace));

outCount = -1; // reset variables
outSizeOf = -1;
EXPECT_RGL_SUCCESS(rgl_graph_get_result_size(format, RGL_FIELD_DYNAMIC_FORMAT, &outCount, &outSizeOf));
EXPECT_EQ(outCount, rays.size());
EXPECT_EQ(outSizeOf, sizeof(formatStructEx));

std::vector<FormatStructExtended> formatDataEx{(size_t)outCount};
EXPECT_RGL_SUCCESS(rgl_graph_get_result_data(format, RGL_FIELD_DYNAMIC_FORMAT, formatDataEx.data()));

for (int i = 0; i < formatDataEx.size(); ++i) {
EXPECT_NEAR(formatDataEx[i].xyz[0], rays[i].value[0][3], EPSILON_F);
EXPECT_NEAR(formatDataEx[i].xyz[1], rays[i].value[1][3], EPSILON_F);
EXPECT_NEAR(formatDataEx[i].xyz[2], EXPECTED_HITPOINT_Z, EPSILON_F);
EXPECT_NEAR(formatDataEx[i].distance, EXPECTED_RAY_DISTANCE, EPSILON_F);
EXPECT_EQ(formatDataEx[i].timestamp, timestamp.asSeconds());
}
}

Expand Down

0 comments on commit 4c31872

Please sign in to comment.