Skip to content

Commit

Permalink
📦 NEW: Policy Temperature Root
Browse files Browse the repository at this point in the history
  • Loading branch information
ianagbip1oti committed Nov 10, 2023
1 parent 20dc367 commit 405a433
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 29 deletions.
11 changes: 11 additions & 0 deletions src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ static HASH_SIZE_MB: AtomicUsize = AtomicUsize::new(16);
static CPUCT: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(1.85));
static CVISITS_SELECTION: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(0.01));
static POLICY_TEMPERATURE: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(1.2));
static POLICY_TEMPERATURE_ROOT: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(3.5));

static CHESS960: AtomicBool = AtomicBool::new(false);

Expand Down Expand Up @@ -58,6 +59,16 @@ pub fn get_policy_temperature() -> f32 {
*pt
}

pub fn set_policy_temperature_root(t: f32) {
let mut pt = POLICY_TEMPERATURE_ROOT.write().unwrap();
*pt = t;
}

pub fn get_policy_temperature_root() -> f32 {
let pt = POLICY_TEMPERATURE_ROOT.read().unwrap();
*pt
}

pub fn set_chess960(c: bool) {
CHESS960.store(c, Ordering::Relaxed);
}
Expand Down
32 changes: 4 additions & 28 deletions src/search_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ use std::sync::atomic::{AtomicI64, AtomicPtr, AtomicU32, AtomicU64, AtomicUsize,

use crate::arena::Error as ArenaError;
use crate::evaluation::{self, Flag};
use crate::math;
use crate::mcts::{eval_in_cp, ThreadData};
use crate::options::{get_cpuct, get_cvisits_selection, get_policy_temperature};
use crate::options::{
get_cpuct, get_cvisits_selection, get_policy_temperature, get_policy_temperature_root,
};
use crate::search::{to_uci, TimeManagement, SCALE};
use crate::state::State;
use crate::transposition_table::{LRAllocator, LRTable, TranspositionTable};
Expand Down Expand Up @@ -62,10 +63,6 @@ impl SearchNode {
Self { hots, flag }
}

pub fn is_visited(&self) -> bool {
self.hots().iter().any(|x| x.visits() > 0)
}

pub fn flag(&self) -> Flag {
self.flag
}
Expand All @@ -86,15 +83,6 @@ impl SearchNode {
unsafe { &*(self.hots) }
}

fn update_policy(&mut self, evals: &[f32]) {
let hots = unsafe { &mut *(self.hots.cast_mut()) };

#[allow(clippy::cast_sign_loss)]
for i in 0..hots.len().min(evals.len()) {
hots[i].policy = (evals[i].clamp(0., 0.99) * SCALE) as u16;
}
}

pub fn clear_children_links(&self) {
let hots = unsafe { &*(self.hots.cast_mut()) };

Expand Down Expand Up @@ -201,24 +189,12 @@ impl SearchTree {
&state,
&tb_hits,
|sz| root_table.arena().allocator().alloc_slice(sz),
1.0,
get_policy_temperature_root(),
)
.expect("Unable to create root node");

previous_table.lookup_into(&state, &mut root_node);

if root_node.is_visited() {
let mut avg_rewards: Vec<f32> = root_node
.hots()
.iter()
.map(|m| m.average_reward().unwrap_or(-SCALE) / SCALE)
.collect();

math::softmax(&mut avg_rewards, 1.0);

root_node.update_policy(&avg_rewards);
}

Self {
root_state: state,
root_node,
Expand Down
4 changes: 3 additions & 1 deletion src/uci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::thread;

use crate::options::{
set_chess960, set_cpuct, set_cvisits_selection, set_hash_size_mb, set_num_threads,
set_policy_temperature,
set_policy_temperature, set_policy_temperature_root,
};
use crate::search::Search;
use crate::search_tree::print_size_list;
Expand Down Expand Up @@ -82,6 +82,7 @@ pub fn uci() {
println!("option name CPuct type string default 1.85");
println!("option name CVisitsSelection type string default 0.01");
println!("option name PolicyTemperature type string default 1.2");
println!("option name PolicyTemperatureRoot type string default 3.5");
println!("option name UCI_Chess960 type check default false");

println!("uciok");
Expand Down Expand Up @@ -143,6 +144,7 @@ impl UciOption {
"cpuct" => self.set_option(set_cpuct),
"cvisitsselection" => self.set_option(set_cvisits_selection),
"policytemperature" => self.set_option(set_policy_temperature),
"policytemperatureroot" => self.set_option(set_policy_temperature_root),
"uci_chess960" => self.set_option(set_chess960),
_ => warn!("Badly formatted or unknown option"),
}
Expand Down

0 comments on commit 405a433

Please sign in to comment.