From 7b675da9c47f4c5cca72c365c40e688113a18f3a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 20 Nov 2024 00:13:34 +0800 Subject: [PATCH] [bp] Fix rng for the column sampler. (#10998) (#11004) --- src/common/random.h | 2 +- src/tree/updater_gpu_hist.cu | 10 ++-------- tests/python/test_updaters.py | 22 ++++++++++++++++++++++ 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/common/random.h b/src/common/random.h index 6d7a1bb499c9..2dfa68e8b193 100644 --- a/src/common/random.h +++ b/src/common/random.h @@ -230,7 +230,7 @@ class ColumnSampler { }; inline auto MakeColumnSampler(Context const* ctx) { - std::uint32_t seed = common::GlobalRandomEngine()(); + std::uint32_t seed = common::GlobalRandom()(); auto rc = collective::Broadcast(ctx, linalg::MakeVec(&seed, 1), 0); collective::SafeColl(rc); auto cs = std::make_shared(seed); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index e126aeb313df..581801080b11 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -867,12 +867,7 @@ class GPUHistMaker : public TreeUpdater { CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device"; info_ = &dmat->Info(); - // Synchronise the column sampling seed - uint32_t column_sampling_seed = common::GlobalRandom()(); - auto rc = collective::Broadcast( - ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0); - SafeColl(rc); - this->column_sampler_ = std::make_shared(column_sampling_seed); + this->column_sampler_ = common::MakeColumnSampler(ctx_); auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()}; dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); @@ -1012,8 +1007,7 @@ class GPUGlobalApproxMaker : public TreeUpdater { monitor_.Start(__func__); CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal(); - uint32_t column_sampling_seed = common::GlobalRandom()(); - this->column_sampler_ = std::make_shared(column_sampling_seed); + this->column_sampler_ = common::MakeColumnSampler(ctx_); p_last_fmat_ = p_fmat; initialised_ = true; diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index 8ec1fdd9d395..dee220f704c7 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -54,6 +54,28 @@ def test_exact_sample_by_node_error(self) -> None: num_boost_round=2, ) + @pytest.mark.parametrize("tree_method", ["approx", "hist"]) + def test_colsample_rng(self, tree_method: str) -> None: + """Test rng has an effect on column sampling.""" + X, y, _ = tm.make_regression(128, 16, use_cupy=False) + reg0 = xgb.XGBRegressor( + n_estimators=2, + colsample_bynode=0.5, + random_state=42, + tree_method=tree_method, + ) + reg0.fit(X, y) + + reg1 = xgb.XGBRegressor( + n_estimators=2, + colsample_bynode=0.5, + random_state=43, + tree_method=tree_method, + ) + reg1.fit(X, y) + + assert list(reg0.feature_importances_) != list(reg1.feature_importances_) + @given( exact_parameter_strategy, hist_parameter_strategy,