Skip to content

Commit

Permalink
code style
Browse files Browse the repository at this point in the history
  • Loading branch information
chuangz0 committed Aug 15, 2023
1 parent 4be6fb9 commit 0955809
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 22 deletions.
12 changes: 4 additions & 8 deletions cpp/src/wholememory_ops/functions/embedding_cache_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ class CacheLineInfo {
uint32_t lfu_count_;
};


template <typename NodeIDT>
class CacheSetUpdater {
public:
Expand All @@ -160,13 +159,12 @@ class CacheSetUpdater {
static constexpr int kScaledCounterBits = 14;

private:

using warp_bq_t =
raft::matrix::detail::select::warpsort::warp_sort_immediate<kCacheSetSize, false, int64_t, int>;

static constexpr int WARP_SIZE = 32;
static constexpr int BLOCK_SIZE = kCacheSetSize;
static_assert(kCacheSetSize == WARP_SIZE,"only support CacheSetSize==32,and BLOCK_SIZE==32\n");
static_assert(kCacheSetSize == WARP_SIZE, "only support CacheSetSize==32,and BLOCK_SIZE==32\n");

public:
struct TempStorage {
Expand Down Expand Up @@ -242,7 +240,7 @@ class CacheSetUpdater {
// candidate_local_id_,
// has_local_id_count);
int64_t candidate_lfu_count0 = -1;
int candidate_local_id0 = -1;
int candidate_local_id0 = -1;
unsigned int match_flag;
// match_flag = WarpMatchLocalIDPairSync(candidate_local_id_[0], cached_local_id);
int64_t estimated_lfu_count = cache_line_info.LfuCountSync();
Expand Down Expand Up @@ -348,15 +346,14 @@ class CacheSetUpdater {
TempStorage& temp_storage,
int cached_local_id)
{

warp_bq_t warp_queue(kCacheSetSize);
const int per_thread_lim = id_count + raft::laneId();

int has_local_id_count = 0;
for (int idx = threadIdx.x; idx < per_thread_lim; idx += BLOCK_SIZE) {
int local_id = -1;
int local_id = -1;
int64_t candidate_lfu_count = -1;
int candidate_local_id = -1;
int candidate_local_id = -1;
if (idx < id_count) {
local_id = gids != nullptr ? gids[idx] - cache_set_start_id : idx;
candidate_lfu_count = cache_set_coverage_counter[local_id];
Expand All @@ -381,7 +378,6 @@ class CacheSetUpdater {
}
__syncthreads();


return has_local_id_count;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ INSTANTIATE_TEST_SUITE_P(WholeGraphCSRWeightedSampleWithoutReplacementOpTests,
.set_center_node_count(35)
.set_graph_node_count(23289)
.set_graph_edge_couont(689403),
WholeGraphCSRWeightedSampleWithoutReplacementTestParam()
WholeGraphCSRWeightedSampleWithoutReplacementTestParam()
.set_memory_type(WHOLEMEMORY_MT_CONTINUOUS)
.set_max_sample_count(300)
.set_center_node_count(256)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ def host_weighted_sample_without_replacement_func(
output_center_localid_tensor = torch.empty((total_sample_count,), dtype=torch.int32)
output_edge_gid_tensor = torch.empty((total_sample_count,), dtype=torch.int64)
center_nodes_count = center_nodes.size(0)
block_size = 128 if max_sample_count <=256 else 256

block_size = 128 if max_sample_count <= 256 else 256

for i in range(center_nodes_count):
node_id = center_nodes[i]
Expand All @@ -66,23 +65,25 @@ def host_weighted_sample_without_replacement_func(
edge_weight_corresponding_ids = torch.tensor([], dtype=col_id_dtype)
for j in range(block_size):
local_gidx = gidx + j
local_edge_weights = torch.tensor( [],dtype=csr_weight_dtype
)
local_edge_weights = torch.tensor([], dtype=csr_weight_dtype)
generated_edge_weight_count = 0
for id in range(j,neighbor_count,block_size):
for id in range(j, neighbor_count, block_size):
local_edge_weights = torch.cat(
(
local_edge_weights,
torch.tensor([host_csr_weight_ptr[start + id]], dtype=csr_weight_dtype),
)
(
local_edge_weights,
torch.tensor(
[host_csr_weight_ptr[start + id]],
dtype=csr_weight_dtype,
),
)
)
generated_edge_weight_count += 1
edge_weight_corresponding_ids = torch.cat(
(
edge_weight_corresponding_ids,
torch.tensor([id], dtype=col_id_dtype),
)
(
edge_weight_corresponding_ids,
torch.tensor([id], dtype=col_id_dtype),
)
)
random_values = (
wg_ops.generate_exponential_distribution_negative_float_cpu(
random_seed, local_gidx, generated_edge_weight_count
Expand Down

0 comments on commit 0955809

Please sign in to comment.