Skip to content

Commit

Permalink
Fix nodes validation on rays change (#220)
Browse files Browse the repository at this point in the history
* Add test

* Clear getGraphRunCtx if rays modified

* Optimize

* Fix optimization

* Remove redundant fields existence checks

* Review changes
  • Loading branch information
msz-rai authored Nov 19, 2023
1 parent 2e91c81 commit 0eedfa6
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 26 deletions.
5 changes: 1 addition & 4 deletions extensions/pcl/src/graph/DownSamplePointsNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ using PCLPoint = pcl::PointXYZL;
void DownSamplePointsNode::validateImpl()
{
IPointsNodeSingleInput::validateImpl();
if (!input->hasField(XYZ_VEC3_F32)) {
auto msg = fmt::format("{} requires XYZ to be present", getName());
throw InvalidPipeline(msg);
}

// Needed to clear cache because fields in the pipeline may have changed
// In fact, the cache manager is no longer useful here
// To be kept/removed in some future refactor (when resolving comment in the `enqueueExecImpl`)
Expand Down
1 change: 0 additions & 1 deletion extensions/pcl/src/graph/NodesPcl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ struct VisualizePointsNode : IPointsNodeSingleInput
void setParameters(const char* windowName, int windowWidth, int windowHeight, bool fullscreen);

// Node
void validateImpl() override;
void enqueueExecImpl() override;

// Node requirements
Expand Down
9 changes: 0 additions & 9 deletions extensions/pcl/src/graph/VisualizePointsNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,6 @@ void VisualizePointsNode::setParameters(const char* windowName, int windowWidth,
visualizeThread->visualizeNodes.push_back(std::dynamic_pointer_cast<VisualizePointsNode>(shared_from_this()));
}

void VisualizePointsNode::validateImpl()
{
IPointsNodeSingleInput::validateImpl();
if (!input->hasField(XYZ_VEC3_F32)) {
auto msg = fmt::format("{} requires XYZ to be present", getName());
throw InvalidPipeline(msg);
}
}

// All calls to the viewers must be executed from the same thread
void VisualizePointsNode::VisualizeThread::runVisualize()
try {
Expand Down
21 changes: 13 additions & 8 deletions src/api/apiCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,25 @@ void createOrUpdateNode(rgl_node_t* nodeRawPtr, Args&&... args)
} else {
node = Node::validatePtr<NodeType>(*nodeRawPtr);
}
// TODO: The magic below detects calls changing rgl_field_t* (e.g. FormatPointsNode)
// TODO: Such changes may require recomputing required fields in RaytraceNode.
// TODO: However, taking care of this manually is very bug prone.
// TODO: There are other ways to automate this, however, for now this should be enough.
bool fieldsModified = ((std::is_same_v<Args, std::vector<rgl_field_t>> || ...));
if (fieldsModified && node->hasGraphRunCtx()) {
node->getGraphRunCtx()->detachAndDestroy();
}

// As of now, there's no guarantee that changing node parameter won't influence other nodes
// Therefore, before changing them, we need to ensure all nodes are idle (not running in GraphRunCtx).
if (node->hasGraphRunCtx()) {
node->getGraphRunCtx()->synchronize();
}

// TODO: The magic below detects calls changing rgl_field_t* (e.g. FormatPointsNode) or changing rays definition
// TODO: Such changes may require recomputing required fields in RaytraceNode
// TODO: or performing validation in nodes dependent on ray count (e.g. SetRingIdsRaysNode)
// TODO: However, taking care of this manually is very bug prone.
// TODO: There are other ways to automate this, however, for now this should be enough.
bool fieldsModified = (std::is_same_v<Args, std::vector<rgl_field_t>> || ...);
bool raysModified = std::is_same_v<NodeType, FromMat3x4fRaysNode>;
bool graphValidationNeeded = fieldsModified || raysModified;
if (graphValidationNeeded && node->hasGraphRunCtx()) {
node->getGraphRunCtx()->markNodesDirty();
}

node->setParameters(std::forward<Args>(args)...);
node->dirty = true;
*nodeRawPtr = node.get();
Expand Down
10 changes: 10 additions & 0 deletions src/graph/GraphRunCtx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ struct GraphRunCtx
*/
void synchronizeNodeCPU(Node::ConstPtr nodeToSynchronize);

/**
* Marks all nodes dirty.
*/
void markNodesDirty()
{
for (auto&& node : nodes) {
node->dirty = true;
}
}

bool isThisThreadGraphThread() const
{
return maybeThread.has_value() && maybeThread->get_id() == std::this_thread::get_id();
Expand Down
22 changes: 18 additions & 4 deletions test/src/graph/nodes/SetRingIdsRaysNodeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class SetRingIdsNodeTest : public RGLTestWithParam<int>
void initializeRingNodeAndIds(int idsCount)
{
setRingIdsNode = nullptr;
std::vector<int> ids(idsCount);
std::iota(ids.begin(), ids.end(), 0);
ringIds = ids;
ringIds.resize(idsCount);
std::iota(ringIds.begin(), ringIds.end(), 0);
}

void initializeRaysAndRaysNode(int rayCount)
{
rayNode = nullptr;
rays.clear();
rays.reserve(rayCount);
for (int i = 0; i < rayCount; i++) {
rays.emplace_back(
Expand Down Expand Up @@ -64,15 +64,29 @@ TEST_P(SetRingIdsNodeTest, invalid_pipeline_less_rays_than_ring_ids)
int32_t idsCount = GetParam();
if (idsCount / 2 == 0) {
return;
};
}

//// Incorrect number of ring ids passed to the rgl_node_rays_set_ring_ids ////
initializeRingNodeAndIds(idsCount);
initializeRaysAndRaysNode(idsCount / 2);
ASSERT_RGL_SUCCESS(rgl_node_rays_from_mat3x4f(&rayNode, rays.data(), rays.size()));
ASSERT_RGL_SUCCESS(rgl_node_rays_set_ring_ids(&setRingIdsNode, ringIds.data(), ringIds.size()));
ASSERT_RGL_SUCCESS(rgl_graph_node_add_child(rayNode, setRingIdsNode));

EXPECT_RGL_INVALID_PIPELINE(rgl_graph_run(setRingIdsNode), "ring ids doesn't match number of rays");

//// Changed number of rays between graph runs ////
// Initialize and run valid pipeline
initializeRingNodeAndIds(idsCount);
initializeRaysAndRaysNode(idsCount);
ASSERT_RGL_SUCCESS(rgl_node_rays_from_mat3x4f(&rayNode, rays.data(), rays.size()));
ASSERT_RGL_SUCCESS(rgl_node_rays_set_ring_ids(&setRingIdsNode, ringIds.data(), ringIds.size()));
ASSERT_RGL_SUCCESS(rgl_graph_node_add_child(rayNode, setRingIdsNode));
EXPECT_RGL_SUCCESS(rgl_graph_run(setRingIdsNode));

// Make pipeline invalid
ASSERT_RGL_SUCCESS(rgl_node_rays_from_mat3x4f(&rayNode, rays.data(), rays.size() / 2));
EXPECT_RGL_INVALID_PIPELINE(rgl_graph_run(setRingIdsNode), "ring ids doesn't match number of rays");
}

TEST_P(SetRingIdsNodeTest, valid_pipeline_equal_number_of_rays_and_ring_ids)
Expand Down

0 comments on commit 0eedfa6

Please sign in to comment.