Skip to content

Commit

Permalink
📦 NEW: Add policy temperature
Browse files Browse the repository at this point in the history
  • Loading branch information
ianagbip1oti committed Nov 9, 2023
1 parent 31a9fa6 commit 20dc367
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 18 deletions.
25 changes: 22 additions & 3 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,25 @@ services:
- ./builds/princhess-main:/engines/princhess-main
- ./syzygy:/syzygy:ro

sprt_gain_5k:
build:
context: .
dockerfile: bin/Dockerfile.cutechess-cli
command:
-engine cmd=/engines/princhess name=princhess
-engine cmd=/engines/princhess-main name=princhess-main

-each proto=uci tc=inf nodes=5000
option.SyzygyPath=/syzygy option.Hash=128 option.Threads=1
-sprt elo0=0 elo1=5 alpha=0.05 beta=0.1
-openings file=/books/4moves_noob.epd format=epd order=random
-games 2 -repeat -rounds 7500
-recover -ratinginterval 10 -concurrency 6
volumes:
- ./target/release/princhess:/engines/princhess
- ./builds/princhess-main:/engines/princhess-main
- ./syzygy:/syzygy:ro

sprt_gain_ltc:
build:
context: .
Expand Down Expand Up @@ -141,12 +160,12 @@ services:
-engine cmd=/engines/princhess-main name=princhess-main
-engine cmd=/engines/princhess name=princhess- option.

-each proto=uci tc=8+0.08
-each proto=uci tc=inf nodes=5000
option.SyzygyPath=/syzygy option.Hash=128 option.Threads=1
-tournament gauntlet
-openings file=/books/4moves_noob.epd format=epd order=random
-games 2 -repeat -rounds 50
-recover -ratinginterval 10 -concurrency 6
-games 2 -repeat -rounds 250
-recover -ratinginterval 100 -concurrency 6
volumes:
- ./target/release/princhess:/engines/princhess
- ./builds/princhess-main:/engines/princhess-main
Expand Down
8 changes: 4 additions & 4 deletions src/evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ pub fn evaluate_state_flag(state: &State, moves: &MoveList) -> Flag {
state.side_to_move().fold_wb(flag, flag.flip())
}

pub fn evaluate_policy(state: &State, moves: &MoveList) -> Vec<f32> {
run_policy_net(state, moves)
pub fn evaluate_policy(state: &State, moves: &MoveList, t: f32) -> Vec<f32> {
run_policy_net(state, moves, t)
}

const QAB: f32 = 256. * 256.;
Expand Down Expand Up @@ -123,7 +123,7 @@ fn run_eval_net(state: &State) -> f32 {
(result as f32 / QAB).tanh()
}

fn run_policy_net(state: &State, moves: &MoveList) -> Vec<f32> {
fn run_policy_net(state: &State, moves: &MoveList, t: f32) -> Vec<f32> {
let mut evalns = Vec::with_capacity(moves.len());

if moves.is_empty() {
Expand All @@ -143,7 +143,7 @@ fn run_policy_net(state: &State, moves: &MoveList) -> Vec<f32> {
}
});

math::softmax(&mut evalns);
math::softmax(&mut evalns, t);

evalns
}
4 changes: 2 additions & 2 deletions src/math.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub fn softmax(arr: &mut [f32]) {
pub fn softmax(arr: &mut [f32], t: f32) {
let max = max(arr);
let mut s = 0.;

for x in &mut *arr {
*x = fastapprox::faster::exp(*x - max);
*x = fastapprox::faster::exp((*x - max) / t);
s += *x;
}
for x in &mut *arr {
Expand Down
4 changes: 3 additions & 1 deletion src/mcts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use std::thread::JoinHandle;

use crate::evaluation;
use crate::options::get_policy_temperature;
use crate::search::{TimeManagement, SCALE};
pub use crate::search_tree::*;
use crate::state::State;
Expand Down Expand Up @@ -123,7 +124,8 @@ impl Mcts {
let root_moves = root_node.hots();

let state_moves = root_state.available_moves();
let state_moves_eval = evaluation::evaluate_policy(root_state, &state_moves);
let state_moves_eval =
evaluation::evaluate_policy(root_state, &state_moves, get_policy_temperature());

let mut moves: Vec<(&HotMoveInfo, f32)> = root_moves.iter().zip(state_moves_eval).collect();
moves.sort_by_key(|(h, e)| (h.average_reward().unwrap_or(*e) * SCALE) as i64);
Expand Down
11 changes: 11 additions & 0 deletions src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ static HASH_SIZE_MB: AtomicUsize = AtomicUsize::new(16);

static CPUCT: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(1.85));
static CVISITS_SELECTION: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(0.01));
static POLICY_TEMPERATURE: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(1.2));

static CHESS960: AtomicBool = AtomicBool::new(false);

Expand Down Expand Up @@ -47,6 +48,16 @@ pub fn get_cvisits_selection() -> f32 {
*cv
}

pub fn set_policy_temperature(t: f32) {
let mut pt = POLICY_TEMPERATURE.write().unwrap();
*pt = t;
}

pub fn get_policy_temperature() -> f32 {
let pt = POLICY_TEMPERATURE.read().unwrap();
*pt
}

pub fn set_chess960(c: bool) {
CHESS960.store(c, Ordering::Relaxed);
}
Expand Down
26 changes: 18 additions & 8 deletions src/search_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::arena::Error as ArenaError;
use crate::evaluation::{self, Flag};
use crate::math;
use crate::mcts::{eval_in_cp, ThreadData};
use crate::options::{get_cpuct, get_cvisits_selection};
use crate::options::{get_cpuct, get_cvisits_selection, get_policy_temperature};
use crate::search::{to_uci, TimeManagement, SCALE};
use crate::state::State;
use crate::transposition_table::{LRAllocator, LRTable, TranspositionTable};
Expand All @@ -26,6 +26,7 @@ pub struct SearchTree {
root_state: State,

cpuct: f32,
policy_t: f32,

#[allow(dead_code)]
root_table: TranspositionTable,
Expand Down Expand Up @@ -162,14 +163,15 @@ fn create_node<'a, F>(
state: &State,
tb_hits: &AtomicUsize,
alloc_slice: F,
policy_t: f32,
) -> Result<SearchNode, ArenaError>
where
F: FnOnce(usize) -> Result<&'a mut [HotMoveInfo], ArenaError>,
{
let moves = state.available_moves();

let state_flag = evaluation::evaluate_state_flag(state, &moves);
let move_eval = evaluation::evaluate_policy(state, &moves);
let move_eval = evaluation::evaluate_policy(state, &moves, policy_t);

if state_flag.is_tablebase() {
tb_hits.fetch_add(1, Ordering::Relaxed);
Expand All @@ -195,9 +197,12 @@ impl SearchTree {

let root_table = TranspositionTable::for_root();

let mut root_node = create_node(&state, &tb_hits, |sz| {
root_table.arena().allocator().alloc_slice(sz)
})
let mut root_node = create_node(
&state,
&tb_hits,
|sz| root_table.arena().allocator().alloc_slice(sz),
1.0,
)
.expect("Unable to create root node");

previous_table.lookup_into(&state, &mut root_node);
Expand All @@ -209,7 +214,7 @@ impl SearchTree {
.map(|m| m.average_reward().unwrap_or(-SCALE) / SCALE)
.collect();

math::softmax(&mut avg_rewards);
math::softmax(&mut avg_rewards, 1.0);

root_node.update_policy(&avg_rewards);
}
Expand All @@ -218,6 +223,7 @@ impl SearchTree {
root_state: state,
root_node,
cpuct: get_cpuct(),
policy_t: get_policy_temperature(),
root_table,
ttable: LRTable::new(current_table, previous_table),
num_nodes: 1.into(),
Expand Down Expand Up @@ -387,8 +393,12 @@ impl SearchTree {
};
}

let mut created_here =
create_node(state, &self.tb_hits, |sz| tld.allocator.alloc_move_info(sz))?;
let mut created_here = create_node(
state,
&self.tb_hits,
|sz| tld.allocator.alloc_move_info(sz),
self.policy_t,
)?;

self.ttable.lookup_into(state, &mut created_here);

Expand Down
3 changes: 3 additions & 0 deletions src/uci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::thread;

use crate::options::{
set_chess960, set_cpuct, set_cvisits_selection, set_hash_size_mb, set_num_threads,
set_policy_temperature,
};
use crate::search::Search;
use crate::search_tree::print_size_list;
Expand Down Expand Up @@ -80,6 +81,7 @@ pub fn uci() {
println!("option name SyzygyPath type string default <empty>");
println!("option name CPuct type string default 1.85");
println!("option name CVisitsSelection type string default 0.01");
println!("option name PolicyTemperature type string default 1.2");
println!("option name UCI_Chess960 type check default false");

println!("uciok");
Expand Down Expand Up @@ -140,6 +142,7 @@ impl UciOption {
"hash" => self.set_option(set_hash_size_mb),
"cpuct" => self.set_option(set_cpuct),
"cvisitsselection" => self.set_option(set_cvisits_selection),
"policytemperature" => self.set_option(set_policy_temperature),
"uci_chess960" => self.set_option(set_chess960),
_ => warn!("Badly formatted or unknown option"),
}
Expand Down

0 comments on commit 20dc367

Please sign in to comment.