diff --git a/src/options.rs b/src/options.rs index bf3ae7a..de169c5 100644 --- a/src/options.rs +++ b/src/options.rs @@ -9,6 +9,7 @@ static HASH_SIZE_MB: AtomicUsize = AtomicUsize::new(16); static CPUCT: Lazy> = Lazy::new(|| RwLock::new(1.85)); static CVISITS_SELECTION: Lazy> = Lazy::new(|| RwLock::new(0.01)); static POLICY_TEMPERATURE: Lazy> = Lazy::new(|| RwLock::new(1.2)); +static POLICY_TEMPERATURE_ROOT: Lazy> = Lazy::new(|| RwLock::new(3.5)); static CHESS960: AtomicBool = AtomicBool::new(false); @@ -58,6 +59,16 @@ pub fn get_policy_temperature() -> f32 { *pt } +pub fn set_policy_temperature_root(t: f32) { + let mut pt = POLICY_TEMPERATURE_ROOT.write().unwrap(); + *pt = t; +} + +pub fn get_policy_temperature_root() -> f32 { + let pt = POLICY_TEMPERATURE_ROOT.read().unwrap(); + *pt +} + pub fn set_chess960(c: bool) { CHESS960.store(c, Ordering::Relaxed); } diff --git a/src/search_tree.rs b/src/search_tree.rs index e510fc8..4acfd8f 100644 --- a/src/search_tree.rs +++ b/src/search_tree.rs @@ -7,9 +7,10 @@ use std::sync::atomic::{AtomicI64, AtomicPtr, AtomicU32, AtomicU64, AtomicUsize, use crate::arena::Error as ArenaError; use crate::evaluation::{self, Flag}; -use crate::math; use crate::mcts::{eval_in_cp, ThreadData}; -use crate::options::{get_cpuct, get_cvisits_selection, get_policy_temperature}; +use crate::options::{ + get_cpuct, get_cvisits_selection, get_policy_temperature, get_policy_temperature_root, +}; use crate::search::{to_uci, TimeManagement, SCALE}; use crate::state::State; use crate::transposition_table::{LRAllocator, LRTable, TranspositionTable}; @@ -62,10 +63,6 @@ impl SearchNode { Self { hots, flag } } - pub fn is_visited(&self) -> bool { - self.hots().iter().any(|x| x.visits() > 0) - } - pub fn flag(&self) -> Flag { self.flag } @@ -86,15 +83,6 @@ impl SearchNode { unsafe { &*(self.hots) } } - fn update_policy(&mut self, evals: &[f32]) { - let hots = unsafe { &mut *(self.hots.cast_mut()) }; - - #[allow(clippy::cast_sign_loss)] - for i in 0..hots.len().min(evals.len()) { - hots[i].policy = (evals[i].clamp(0., 0.99) * SCALE) as u16; - } - } - pub fn clear_children_links(&self) { let hots = unsafe { &*(self.hots.cast_mut()) }; @@ -201,24 +189,12 @@ impl SearchTree { &state, &tb_hits, |sz| root_table.arena().allocator().alloc_slice(sz), - 1.0, + get_policy_temperature_root(), ) .expect("Unable to create root node"); previous_table.lookup_into(&state, &mut root_node); - if root_node.is_visited() { - let mut avg_rewards: Vec = root_node - .hots() - .iter() - .map(|m| m.average_reward().unwrap_or(-SCALE) / SCALE) - .collect(); - - math::softmax(&mut avg_rewards, 1.0); - - root_node.update_policy(&avg_rewards); - } - Self { root_state: state, root_node, diff --git a/src/uci.rs b/src/uci.rs index 0bae297..c8ce23c 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -5,7 +5,7 @@ use std::thread; use crate::options::{ set_chess960, set_cpuct, set_cvisits_selection, set_hash_size_mb, set_num_threads, - set_policy_temperature, + set_policy_temperature, set_policy_temperature_root, }; use crate::search::Search; use crate::search_tree::print_size_list; @@ -82,6 +82,7 @@ pub fn uci() { println!("option name CPuct type string default 1.85"); println!("option name CVisitsSelection type string default 0.01"); println!("option name PolicyTemperature type string default 1.2"); + println!("option name PolicyTemperatureRoot type string default 3.5"); println!("option name UCI_Chess960 type check default false"); println!("uciok"); @@ -143,6 +144,7 @@ impl UciOption { "cpuct" => self.set_option(set_cpuct), "cvisitsselection" => self.set_option(set_cvisits_selection), "policytemperature" => self.set_option(set_policy_temperature), + "policytemperatureroot" => self.set_option(set_policy_temperature_root), "uci_chess960" => self.set_option(set_chess960), _ => warn!("Badly formatted or unknown option"), }