Skip to content

Commit

Permalink
Remove internal usage of multiple scenes (#213)
Browse files Browse the repository at this point in the history
* Remove internal usage of multiple scenes

* Review fixes
  • Loading branch information
prybicki authored Nov 10, 2023
1 parent ae14a87 commit eea3102
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 71 deletions.
4 changes: 2 additions & 2 deletions extensions/ros2/src/graph/Ros2PublishPointsNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ void Ros2PublishPointsNode::enqueueExecImpl()
ros2Message.row_step = ros2Message.point_step * ros2Message.width;
// TODO(msz-rai): Assign scene to the Graph.
// For now, only default scene is supported.
ros2Message.header.stamp = Scene::defaultInstance()->getTime().has_value() ?
Scene::defaultInstance()->getTime()->asRos2Msg() :
ros2Message.header.stamp = Scene::instance().getTime().has_value() ?
Scene::instance().getTime()->asRos2Msg() :
static_cast<builtin_interfaces::msg::Time>(ros2Node->get_clock()->now());
if (!rclcpp::ok()) {
throw std::runtime_error("Unable to publish a message because ROS2 has been shut down.");
Expand Down
47 changes: 15 additions & 32 deletions src/api/apiCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ RGL_API rgl_status_t rgl_cleanup(void)
Entity::instances.clear();
Mesh::instances.clear();
Texture::instances.clear();
Scene::defaultInstance()->clear();
Scene::instance().clear();
});
TAPE_HOOK();
return status;
Expand Down Expand Up @@ -256,11 +256,9 @@ RGL_API rgl_status_t rgl_entity_create(rgl_entity_t* out_entity, rgl_scene_t sce
RGL_API_LOG("rgl_entity_create(out_entity={}, scene={}, mesh={})", (void*) out_entity, (void*) scene, (void*) mesh);
CHECK_ARG(out_entity != nullptr);
CHECK_ARG(mesh != nullptr);
CHECK_ARG(scene == nullptr); // TODO: remove once rgl_scene_t param is removed
GraphRunCtx::synchronizeAll(); // Prevent races with graph threads
if (scene == nullptr) {
scene = Scene::defaultInstance().get();
}
*out_entity = Entity::create(Mesh::validatePtr(mesh), Scene::validatePtr(scene)).get();
*out_entity = Entity::create(Mesh::validatePtr(mesh)).get();
});
TAPE_HOOK(out_entity, scene, mesh);
return status;
Expand All @@ -269,9 +267,7 @@ RGL_API rgl_status_t rgl_entity_create(rgl_entity_t* out_entity, rgl_scene_t sce
void TapePlayer::tape_entity_create(const YAML::Node& yamlNode)
{
rgl_entity_t entity = nullptr;
rgl_entity_create(&entity,
nullptr, // TODO(msz-rai) support multiple scenes
tapeMeshes.at(yamlNode[2].as<TapeAPIObjectID>()));
rgl_entity_create(&entity, nullptr, tapeMeshes.at(yamlNode[2].as<TapeAPIObjectID>()));
tapeEntities.insert(std::make_pair(yamlNode[0].as<TapeAPIObjectID>(), entity));
}

Expand All @@ -282,7 +278,7 @@ RGL_API rgl_status_t rgl_entity_destroy(rgl_entity_t entity)
CHECK_ARG(entity != nullptr);
GraphRunCtx::synchronizeAll(); // Prevent races with graph threads
auto entitySafe = Entity::validatePtr(entity);
entitySafe->getScene()->removeEntity(entitySafe);
Scene::instance().removeEntity(entitySafe);
Entity::release(entity);
});
TAPE_HOOK(entity);
Expand Down Expand Up @@ -402,21 +398,16 @@ RGL_API rgl_status_t rgl_scene_set_time(rgl_scene_t scene, uint64_t nanoseconds)
{
auto status = rglSafeCall([&]() {
RGL_API_LOG("rgl_scene_set_time(scene={}, nanoseconds={})", (void*) scene, nanoseconds);
CHECK_ARG(scene == nullptr); // TODO: remove once rgl_scene_t param is removed
GraphRunCtx::synchronizeAll(); // Prevent races with graph threads
if (scene == nullptr) {
scene = Scene::defaultInstance().get();
}
Scene::validatePtr(scene)->setTime(Time::nanoseconds(nanoseconds));

Scene::instance().setTime(Time::nanoseconds(nanoseconds));
});
TAPE_HOOK(scene, nanoseconds);
return status;
}

void TapePlayer::tape_scene_set_time(const YAML::Node& yamlNode)
{
rgl_scene_set_time(nullptr, // TODO(msz-rai) support multiple scenes
yamlNode[1].as<uint64_t>());
}
void TapePlayer::tape_scene_set_time(const YAML::Node& yamlNode) { rgl_scene_set_time(nullptr, yamlNode[1].as<uint64_t>()); }

RGL_API rgl_status_t rgl_graph_run(rgl_node_t raw_node)
{
Expand Down Expand Up @@ -810,12 +801,9 @@ RGL_API rgl_status_t rgl_node_raytrace(rgl_node_t* node, rgl_scene_t scene)
auto status = rglSafeCall([&]() {
RGL_API_LOG("rgl_node_raytrace(node={}, scene={})", repr(node), (void*) scene);
CHECK_ARG(node != nullptr);
CHECK_ARG(scene == nullptr); // TODO: remove once rgl_scene_t param is removed

if (scene == nullptr) {
scene = Scene::defaultInstance().get();
}

createOrUpdateNode<RaytraceNode>(node, Scene::validatePtr(scene));
createOrUpdateNode<RaytraceNode>(node);
// Clear velocity that could be set by rgl_node_raytrace_with_distortion
Node::validatePtr<RaytraceNode>(*node)->setVelocity(nullptr, nullptr);
});
Expand All @@ -827,7 +815,7 @@ void TapePlayer::tape_node_raytrace(const YAML::Node& yamlNode)
{
auto nodeId = yamlNode[0].as<TapeAPIObjectID>();
rgl_node_t node = tapeNodes.contains(nodeId) ? tapeNodes.at(nodeId) : nullptr;
rgl_node_raytrace(&node, nullptr); // TODO(msz-rai) support multiple scenes
rgl_node_raytrace(&node, nullptr);
tapeNodes.insert({nodeId, node});
}

Expand All @@ -840,12 +828,9 @@ RGL_API rgl_status_t rgl_node_raytrace_with_distortion(rgl_node_t* node, rgl_sce
CHECK_ARG(node != nullptr);
CHECK_ARG(linear_velocity != nullptr);
CHECK_ARG(angular_velocity != nullptr);
CHECK_ARG(scene == nullptr);

if (scene == nullptr) {
scene = Scene::defaultInstance().get();
}

createOrUpdateNode<RaytraceNode>(node, Scene::validatePtr(scene));
createOrUpdateNode<RaytraceNode>(node);
Node::validatePtr<RaytraceNode>(*node)->setVelocity(reinterpret_cast<const Vec3f*>(linear_velocity),
reinterpret_cast<const Vec3f*>(angular_velocity));
});
Expand All @@ -857,9 +842,7 @@ void TapePlayer::tape_node_raytrace_with_distortion(const YAML::Node& yamlNode)
{
auto nodeId = yamlNode[0].as<TapeAPIObjectID>();
rgl_node_t node = tapeNodes.contains(nodeId) ? tapeNodes.at(nodeId) : nullptr;
rgl_node_raytrace_with_distortion(&node,
nullptr, // TODO(msz-rai) support multiple scenes
reinterpret_cast<const rgl_vec3f*>(fileMmap + yamlNode[2].as<size_t>()),
rgl_node_raytrace_with_distortion(&node, nullptr, reinterpret_cast<const rgl_vec3f*>(fileMmap + yamlNode[2].as<size_t>()),
reinterpret_cast<const rgl_vec3f*>(fileMmap + yamlNode[3].as<size_t>()));
tapeNodes.insert({nodeId, node});
}
Expand Down
3 changes: 1 addition & 2 deletions src/graph/NodesCore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ struct CompactPointsNode : IPointsNodeSingleInput
struct RaytraceNode : IPointsNode
{
using Ptr = std::shared_ptr<RaytraceNode>;
void setParameters(std::shared_ptr<Scene> scene);
void setParameters();

// Node
void validateImpl() override;
Expand All @@ -122,7 +122,6 @@ struct RaytraceNode : IPointsNode

private:
IRaysNode::Ptr raysNode;
std::shared_ptr<Scene> scene;

DeviceAsyncArray<Vec2f>::Ptr defaultRange = DeviceAsyncArray<Vec2f>::create(arrayMgr);
bool doApplyDistortion{false};
Expand Down
11 changes: 5 additions & 6 deletions src/graph/RaytraceNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
#include <macros/optix.hpp>
#include <RGLFields.hpp>

void RaytraceNode::setParameters(std::shared_ptr<Scene> scene)
void RaytraceNode::setParameters()
{
this->scene = scene;
const static Vec2f defaultRangeValue = Vec2f(0.0f, FLT_MAX);
defaultRange->copyFromExternal(&defaultRangeValue, 1);
}
Expand All @@ -39,7 +38,7 @@ void RaytraceNode::validateImpl()
throw InvalidPipeline(msg);
}

if (fieldData.contains(TIME_STAMP_F64) && !scene->getTime().has_value()) {
if (fieldData.contains(TIME_STAMP_F64) && !Scene::instance().getTime().has_value()) {
auto msg = fmt::format("requested for field TIME_STAMP_F64, but RaytraceNode cannot get time from scene");
throw InvalidPipeline(msg);
}
Expand Down Expand Up @@ -68,8 +67,8 @@ void RaytraceNode::enqueueExecImpl()

// Even though we are in graph thread here, we can access Scene class (see comment there)
const Mat3x4f* raysPtr = raysNode->getRays()->asSubclass<DeviceAsyncArray>()->getReadPtr();
auto sceneAS = scene->getASLocked();
auto sceneSBT = scene->getSBTLocked();
auto sceneAS = Scene::instance().getASLocked();
auto sceneSBT = Scene::instance().getSBTLocked();
dim3 launchDims = {static_cast<unsigned int>(raysNode->getRayCount()), 1, 1};

// Optional
Expand All @@ -94,7 +93,7 @@ void RaytraceNode::enqueueExecImpl()
.rayTimeOffsets = timeOffsets.has_value() ? (*timeOffsets)->asSubclass<DeviceAsyncArray>()->getReadPtr() : nullptr,
.rayTimeOffsetsCount = timeOffsets.has_value() ? (*timeOffsets)->getCount() : 0,
.scene = sceneAS,
.sceneTime = scene->getTime().has_value() ? scene->getTime()->asSeconds() : 0,
.sceneTime = Scene::instance().getTime().value_or(Time::zero()).asSeconds(),
.xyz = getPtrTo<XYZ_VEC3_F32>(),
.isHit = getPtrTo<IS_HIT_I32>(),
.rayIdx = getPtrTo<RAY_IDX_U32>(),
Expand Down
10 changes: 5 additions & 5 deletions src/scene/Entity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

API_OBJECT_INSTANCE(Entity);

std::shared_ptr<Entity> Entity::create(std::shared_ptr<Mesh> mesh, std::shared_ptr<Scene> scene)
std::shared_ptr<Entity> Entity::create(std::shared_ptr<Mesh> mesh)
{
auto entity = APIObject<Entity>::create(mesh);
scene->addEntity(entity);
Scene::instance().addEntity(entity);
return entity;
}

Expand All @@ -28,7 +28,7 @@ Entity::Entity(std::shared_ptr<Mesh> mesh) : mesh(std::move(mesh)) {}
void Entity::setTransform(Mat3x4f newTransform)
{
transform = newTransform;
scene->requestASRebuild();
Scene::instance().requestASRebuild();
}

void Entity::setId(int newId)
Expand All @@ -39,13 +39,13 @@ void Entity::setId(int newId)
throw std::invalid_argument(msg);
}
id = newId;
scene->requestASRebuild();
Scene::instance().requestASRebuild();
}

void Entity::setIntensityTexture(std::shared_ptr<Texture> texture)
{
intensityTexture = texture;
scene->requestSBTRebuild();
Scene::instance().requestSBTRebuild();
}

Mat3x4f Entity::getVelocity() const { throw std::runtime_error("unimplemented"); }
8 changes: 1 addition & 7 deletions src/scene/Entity.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct Entity : APIObject<Entity>
* Factory methods which creates an Entity and adds it to the given Scene.
* See constructor docs for more details.
*/
static std::shared_ptr<Entity> create(std::shared_ptr<Mesh> mesh, std::shared_ptr<Scene> scene);
static std::shared_ptr<Entity> create(std::shared_ptr<Mesh> mesh);

/**
* Sets ID that will be used as a point attribute ENTITY_ID_I32 when a ray hits this entity.
Expand All @@ -61,11 +61,6 @@ struct Entity : APIObject<Entity>
*/
Mat3x4f getVelocity() const;

/**
* @return The Scene in which this Entity is present.
*/
std::shared_ptr<Scene> getScene() const { return scene; }

private:
/**
* Creates Entity with given mesh and identity transform.
Expand All @@ -82,5 +77,4 @@ struct Entity : APIObject<Entity>

std::shared_ptr<Mesh> mesh{};
std::shared_ptr<Texture> intensityTexture{};
std::shared_ptr<Scene> scene{};
};
3 changes: 3 additions & 0 deletions src/scene/Mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <scene/Mesh.hpp>
#include <scene/Scene.hpp>

#include <filesystem>

Expand All @@ -35,6 +36,7 @@ void Mesh::updateVertices(const Vec3f* vertices, std::size_t vertexCount)
}
dVertices->copyFromExternal(vertices, vertexCount);
gasNeedsUpdate = true;
Scene::instance().requestASRebuild();
}

OptixTraversableHandle Mesh::getGAS(CudaStream::Ptr stream)
Expand Down Expand Up @@ -142,4 +144,5 @@ void Mesh::setTexCoords(const Vec2f* texCoords, std::size_t texCoordCount)

dTextureCoords.value()->copyFromExternal(texCoords, texCoordCount);
gasNeedsUpdate = true;
Scene::instance().requestSBTRebuild();
}
18 changes: 6 additions & 12 deletions src/scene/Scene.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
#include <scene/Texture.hpp>
#include <memory/Array.hpp>

API_OBJECT_INSTANCE(Scene);

std::shared_ptr<Scene> Scene::defaultInstance()
Scene& Scene::instance()
{
static auto scene = Scene::create();
static Scene scene;
return scene;
}

Expand All @@ -32,24 +30,20 @@ std::size_t Scene::getObjectCount() { return entities.size(); }
void Scene::clear()
{
entities.clear();
requestFullRebuild();
requestASRebuild();
requestSBTRebuild();
}

void Scene::addEntity(std::shared_ptr<Entity> entity)
{
entity->scene = shared_from_this();
entities.insert(entity);
requestFullRebuild();
requestASRebuild();
requestSBTRebuild();
}

void Scene::removeEntity(std::shared_ptr<Entity> entity)
{
entities.erase(entity);
requestFullRebuild();
}

void Scene::requestFullRebuild()
{
requestASRebuild();
requestSBTRebuild();
}
Expand Down
12 changes: 8 additions & 4 deletions src/scene/Scene.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,14 @@ struct Entity;
* The only case when graph thread accesses scene is getAS() and getSBT(), which are locked.
*
*/
struct Scene : APIObject<Scene>, std::enable_shared_from_this<Scene>
struct Scene
{
static std::shared_ptr<Scene> defaultInstance();
static Scene& instance();

Scene();
Scene(const Scene&) = delete;
Scene(Scene&&) = delete;
Scene& operator=(const Scene&) = delete;
Scene& operator=(Scene&&) = delete;

void addEntity(std::shared_ptr<Entity> entity);
void removeEntity(std::shared_ptr<Entity> entity);
Expand All @@ -63,11 +66,12 @@ struct Scene : APIObject<Scene>, std::enable_shared_from_this<Scene>
OptixTraversableHandle getASLocked();
OptixShaderBindingTable getSBTLocked();

void requestFullRebuild();
void requestASRebuild();
void requestSBTRebuild();

private:
Scene();

OptixShaderBindingTable buildSBT();
OptixTraversableHandle buildAS();

Expand Down
2 changes: 1 addition & 1 deletion test/src/apiSurfaceTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ TEST_F(APISurfaceTests, rgl_entity_create_destroy)
// Invalid args, note: scene can be nullptr here.
EXPECT_RGL_INVALID_ARGUMENT(rgl_entity_create(nullptr, nullptr, nullptr), "entity != nullptr");
EXPECT_RGL_INVALID_ARGUMENT(rgl_entity_create(&entity, nullptr, nullptr), "mesh != nullptr");
EXPECT_RGL_INVALID_OBJECT(rgl_entity_create(&entity, (rgl_scene_t) 0x1234, mesh), "Scene 0x1234");
EXPECT_RGL_INVALID_ARGUMENT(rgl_entity_create(&entity, (rgl_scene_t) 0x1234, mesh), "scene == nullptr");
EXPECT_RGL_INVALID_OBJECT(rgl_entity_create(&entity, nullptr, (rgl_mesh_t) 0x1234), "Mesh 0x1234");

// Correct create
Expand Down

0 comments on commit eea3102

Please sign in to comment.