Skip to content

Commit

Permalink
Fix remaining tests, rename device api args to clarify they take Index.
Browse files Browse the repository at this point in the history
  • Loading branch information
Robadob committed Oct 5, 2023
1 parent cc08a8a commit 3aecb7e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ class DeviceEnvironmentDirectedGraph {
* Constructor, takes the search parameters required
* Begin key and end key specify the [begin, end) contiguous range of bucket. (inclusive begin, exclusive end)
* @param _graph_hash Graph hash for accessing data via curve
* @param vertexID The index (ID in future?) of the vertex to retrieve leaving edges for
* @param vertex_index The index of the vertex to retrieve leaving edges for
*/
inline __device__ OutEdgeFilter(detail::curve::Curve::VariableHash _graph_hash, const id_t &vertexID);
inline __device__ OutEdgeFilter(detail::curve::Curve::VariableHash _graph_hash, const id_t &vertex_index);
/**
* Returns an iterator to the start of the message list subset about the search origin
*/
Expand Down Expand Up @@ -350,9 +350,9 @@ class DeviceEnvironmentDirectedGraph {
* Constructor, takes the search parameters required
* Begin key and end key specify the [begin, end) contiguous range of bucket. (inclusive begin, exclusive end)
* @param _graph_hash Graph hash for accessing data via curve
* @param vertexID The index (ID in future?) of the vertex to retrieve leaving edges for
* @param vertex_index The index of the vertex to retrieve leaving edges for
*/
inline __device__ InEdgeFilter(detail::curve::Curve::VariableHash _graph_hash, const id_t &vertexID);
inline __device__ InEdgeFilter(detail::curve::Curve::VariableHash _graph_hash, const id_t &vertex_index);
/**
* Returns an iterator to the start of the message list subset about the search origin
*/
Expand Down Expand Up @@ -407,57 +407,63 @@ class DeviceEnvironmentDirectedGraph {
template<typename T, flamegpu::size_type N, unsigned int M>
__device__ __forceinline__ T getEdgeProperty(const char(&property_name)[M], unsigned int edge_index, unsigned int element_index) const;

/**
* Returns a Filter object which provides access to an edge iterator
* for iterating a subset of edge which leave the specified vertex
*
* @param vertexID The index (ID in future?) of the vertex to retrieve leaving edges for
*/
inline __device__ OutEdgeFilter outEdges(const id_t & vertexID) const {
return OutEdgeFilter(graph_hash, vertexID);
/**
* Returns a Filter object which provides access to an edge iterator
* for iterating a subset of edge which leave the specified vertex
*
* @param vertex_index The index of the vertex to retrieve leaving edges for
*/
inline __device__ OutEdgeFilter outEdges(const id_t & vertex_index) const {
return OutEdgeFilter(graph_hash, vertex_index);
}
inline __device__ InEdgeFilter inEdges(const id_t& vertexID) const {
return InEdgeFilter(graph_hash, vertexID);
/**
* Returns a Filter object which provides access to an edge iterator
* for iterating a subset of edge which join the specified vertex
*
* @param vertex_index The index of the vertex to retrieve joining edges for
*/
inline __device__ InEdgeFilter inEdges(const id_t& vertex_index) const {
return InEdgeFilter(graph_hash, vertex_index);
}
};
__device__ DeviceEnvironmentDirectedGraph::OutEdgeFilter::OutEdgeFilter(const detail::curve::Curve::VariableHash _graph_hash, const id_t& vertexID)
__device__ DeviceEnvironmentDirectedGraph::OutEdgeFilter::OutEdgeFilter(const detail::curve::Curve::VariableHash _graph_hash, const id_t& vertex_index)
: bucket_begin(0)
, bucket_end(0)
, graph_hash(_graph_hash) {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
// Vertex "_id" always exists
const unsigned int VERTEX_COUNT = detail::curve::DeviceCurve::getVariableCount("_id", graph_hash ^ detail::curve::Curve::variableHash("_environment_directed_graph_vertex"));
if (vertexID >= VERTEX_COUNT) {
DTHROW("Vertex index (%u) exceeds vertex count (%u), unable to iterate outgoing edges.\n", vertexID, VERTEX_COUNT);
if (vertex_index >= VERTEX_COUNT) {
DTHROW("Vertex index (%u) exceeds vertex count (%u), unable to iterate outgoing edges.\n", vertex_index, VERTEX_COUNT);
return;
}
#endif
unsigned int* pbm = detail::curve::DeviceCurve::getEnvironmentDirectedGraphPBM(graph_hash);
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
if (!pbm) return;
#endif
bucket_begin = pbm[vertexID];
bucket_end = pbm[vertexID + 1];
bucket_begin = pbm[vertex_index];
bucket_end = pbm[vertex_index + 1];
}
__device__ DeviceEnvironmentDirectedGraph::InEdgeFilter::InEdgeFilter(const detail::curve::Curve::VariableHash _graph_hash, const id_t& vertexID)
__device__ DeviceEnvironmentDirectedGraph::InEdgeFilter::InEdgeFilter(const detail::curve::Curve::VariableHash _graph_hash, const id_t& vertex_index)
: bucket_begin(0)
, bucket_end(0)
, graph_ipbm_edges(nullptr)
, graph_hash(_graph_hash) {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
// Vertex "_id" always exists
const unsigned int VERTEX_COUNT = detail::curve::DeviceCurve::getVariableCount("_id", graph_hash ^ detail::curve::Curve::variableHash("_environment_directed_graph_vertex"));
if (vertexID >= VERTEX_COUNT) {
DTHROW("Vertex index (%u) exceeds vertex count (%u), unable to iterate incoming edges.\n", vertexID, VERTEX_COUNT);
if (vertex_index >= VERTEX_COUNT) {
DTHROW("Vertex index (%u) exceeds vertex count (%u), unable to iterate incoming edges.\n", vertex_index, VERTEX_COUNT);
return;
}
#endif
unsigned int* ipbm = detail::curve::DeviceCurve::getEnvironmentDirectedGraphIPBM(graph_hash);
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
if (!ipbm) return;
#endif
bucket_begin = ipbm[vertexID];
bucket_end = ipbm[vertexID + 1];
bucket_begin = ipbm[vertex_index];
bucket_end = ipbm[vertex_index + 1];
// Grab and store a copy of the PBM edgelist pointer
this->graph_ipbm_edges = detail::curve::DeviceCurve::getEnvironmentDirectedGraphIPBMEdges(_graph_hash);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,9 @@ TEST(TestEnvironmentDirectedGraph, TestHostGetResetGet) {
// Check the data persists
model.newLayer().addHostFunction(HostCheckGraph);
// Init graph with different known data
//model.newLayer().addHostFunction(InitGraph3);
model.newLayer().addHostFunction(InitGraph3);
// Check the data persists
//model.newLayer().addHostFunction(HostCheckGraph3);
model.newLayer().addHostFunction(HostCheckGraph3);

CUDASimulation sim(model);

Expand Down Expand Up @@ -731,7 +731,7 @@ FLAMEGPU_AGENT_FUNCTION(IterateEdgesOut, MessageNone, MessageNone) {
auto filter = graph.outEdges(src);
FLAMEGPU->setVariable<int>("count2", filter.size());
for (auto &edge : filter) {
src_all_correct &= edge.getProperty<id_t>("src_copy") == src;
src_all_correct &= edge.getProperty<id_t>("src_copy") == graph.getVertexID(src);
FLAMEGPU->setVariable<id_t, 5>("dests", ct, graph.getVertexID(edge.getEdgeDestination()));
++ct;
}
Expand All @@ -756,11 +756,12 @@ FLAMEGPU_AGENT_FUNCTION(IterateEdgesIn, MessageNone, MessageNone) {
id_t dest = FLAMEGPU->getIndex();
unsigned int ct = 0;
bool dest_all_correct = true;
auto filter = FLAMEGPU->environment.getDirectedGraph("graph").inEdges(dest);
auto graph = FLAMEGPU->environment.getDirectedGraph("graph");
auto filter = graph.inEdges(dest);
FLAMEGPU->setVariable<int>("count2", filter.size());
for (auto& edge : filter) {
dest_all_correct &= edge.getProperty<id_t>("dest_copy") == dest;
FLAMEGPU->setVariable<id_t, 5>("srcs", ct, edge.getEdgeSource());
dest_all_correct &= edge.getProperty<id_t>("dest_copy") == graph.getVertexID(dest);
FLAMEGPU->setVariable<id_t, 5>("srcs", ct, graph.getVertexID(edge.getEdgeSource()));
++ct;
}
FLAMEGPU->setVariable<int>("count", ct);
Expand Down Expand Up @@ -858,7 +859,7 @@ TEST(TestEnvironmentDirectedGraph, TestEdgesOut) {
EXPECT_EQ(agt.getVariable<int>("count"), 5 - k);
EXPECT_EQ(agt.getVariable<int>("count2"), 5 - k);
for (int i = 0; i < 5 - k; ++i) {
EXPECT_EQ(agt.getVariable<id_t>("dests", i), static_cast<id_t>(k + i));
EXPECT_EQ(agt.getVariable<id_t>("dests", i), static_cast<id_t>(k + i) + 1);
}
++k;
}
Expand Down Expand Up @@ -895,7 +896,7 @@ TEST(TestEnvironmentDirectedGraph, TestEdgesIn) {
EXPECT_EQ(agt.getVariable<int>("count"), 5 - k);
EXPECT_EQ(agt.getVariable<int>("count2"), 5 - k);
for (int i = 0; i < 5 - k; ++i) {
EXPECT_EQ(agt.getVariable<id_t>("srcs", i), static_cast<id_t>(k + i));
EXPECT_EQ(agt.getVariable<id_t>("srcs", i), static_cast<id_t>(k + i) + 1);
}
++k;
}
Expand Down

0 comments on commit 3aecb7e

Please sign in to comment.