diff --git a/cpp/src/wholememory_ops/functions/embedding_cache_func.cuh b/cpp/src/wholememory_ops/functions/embedding_cache_func.cuh index c46bf6e65..e3589285f 100644 --- a/cpp/src/wholememory_ops/functions/embedding_cache_func.cuh +++ b/cpp/src/wholememory_ops/functions/embedding_cache_func.cuh @@ -151,7 +151,6 @@ class CacheLineInfo { uint32_t lfu_count_; }; - template class CacheSetUpdater { public: @@ -160,13 +159,12 @@ class CacheSetUpdater { static constexpr int kScaledCounterBits = 14; private: - using warp_bq_t = raft::matrix::detail::select::warpsort::warp_sort_immediate; static constexpr int WARP_SIZE = 32; static constexpr int BLOCK_SIZE = kCacheSetSize; - static_assert(kCacheSetSize == WARP_SIZE,"only support CacheSetSize==32,and BLOCK_SIZE==32\n"); + static_assert(kCacheSetSize == WARP_SIZE, "only support CacheSetSize==32,and BLOCK_SIZE==32\n"); public: struct TempStorage { @@ -242,7 +240,7 @@ class CacheSetUpdater { // candidate_local_id_, // has_local_id_count); int64_t candidate_lfu_count0 = -1; - int candidate_local_id0 = -1; + int candidate_local_id0 = -1; unsigned int match_flag; // match_flag = WarpMatchLocalIDPairSync(candidate_local_id_[0], cached_local_id); int64_t estimated_lfu_count = cache_line_info.LfuCountSync(); @@ -348,15 +346,14 @@ class CacheSetUpdater { TempStorage& temp_storage, int cached_local_id) { - warp_bq_t warp_queue(kCacheSetSize); const int per_thread_lim = id_count + raft::laneId(); int has_local_id_count = 0; for (int idx = threadIdx.x; idx < per_thread_lim; idx += BLOCK_SIZE) { - int local_id = -1; + int local_id = -1; int64_t candidate_lfu_count = -1; - int candidate_local_id = -1; + int candidate_local_id = -1; if (idx < id_count) { local_id = gids != nullptr ? gids[idx] - cache_set_start_id : idx; candidate_lfu_count = cache_set_coverage_counter[local_id]; @@ -381,7 +378,6 @@ class CacheSetUpdater { } __syncthreads(); - return has_local_id_count; } }; diff --git a/cpp/tests/wholegraph_ops/wholegraph_csr_weighted_sample_without_replacement_tests.cu b/cpp/tests/wholegraph_ops/wholegraph_csr_weighted_sample_without_replacement_tests.cu index fa8cd4f10..eac1723af 100644 --- a/cpp/tests/wholegraph_ops/wholegraph_csr_weighted_sample_without_replacement_tests.cu +++ b/cpp/tests/wholegraph_ops/wholegraph_csr_weighted_sample_without_replacement_tests.cu @@ -446,7 +446,7 @@ INSTANTIATE_TEST_SUITE_P(WholeGraphCSRWeightedSampleWithoutReplacementOpTests, .set_center_node_count(35) .set_graph_node_count(23289) .set_graph_edge_couont(689403), - WholeGraphCSRWeightedSampleWithoutReplacementTestParam() + WholeGraphCSRWeightedSampleWithoutReplacementTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_max_sample_count(300) .set_center_node_count(256) diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py index 0f53044bd..138163a87 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py @@ -46,8 +46,7 @@ def host_weighted_sample_without_replacement_func( output_center_localid_tensor = torch.empty((total_sample_count,), dtype=torch.int32) output_edge_gid_tensor = torch.empty((total_sample_count,), dtype=torch.int64) center_nodes_count = center_nodes.size(0) - block_size = 128 if max_sample_count <=256 else 256 - + block_size = 128 if max_sample_count <= 256 else 256 for i in range(center_nodes_count): node_id = center_nodes[i] @@ -66,23 +65,25 @@ def host_weighted_sample_without_replacement_func( edge_weight_corresponding_ids = torch.tensor([], dtype=col_id_dtype) for j in range(block_size): local_gidx = gidx + j - local_edge_weights = torch.tensor( [],dtype=csr_weight_dtype - ) + local_edge_weights = torch.tensor([], dtype=csr_weight_dtype) generated_edge_weight_count = 0 - for id in range(j,neighbor_count,block_size): + for id in range(j, neighbor_count, block_size): local_edge_weights = torch.cat( - ( - local_edge_weights, - torch.tensor([host_csr_weight_ptr[start + id]], dtype=csr_weight_dtype), - ) + ( + local_edge_weights, + torch.tensor( + [host_csr_weight_ptr[start + id]], + dtype=csr_weight_dtype, + ), + ) ) generated_edge_weight_count += 1 edge_weight_corresponding_ids = torch.cat( - ( - edge_weight_corresponding_ids, - torch.tensor([id], dtype=col_id_dtype), - ) + ( + edge_weight_corresponding_ids, + torch.tensor([id], dtype=col_id_dtype), ) + ) random_values = ( wg_ops.generate_exponential_distribution_negative_float_cpu( random_seed, local_gidx, generated_edge_weight_count