Skip to content

Commit

Permalink
📦 NEW: Add MultiPV support
Browse files Browse the repository at this point in the history
  • Loading branch information
ianagbip1oti committed Jun 27, 2024
1 parent 8f6fa71 commit 2525b4a
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 63 deletions.
9 changes: 9 additions & 0 deletions src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::RwLock;

static NUM_THREADS: AtomicUsize = AtomicUsize::new(1);
static HASH_SIZE_MB: AtomicUsize = AtomicUsize::new(16);
static MULTI_PV: AtomicUsize = AtomicUsize::new(1);

static CPUCT: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(1.06));
static CPUCT_ROOT: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(3.17));
Expand All @@ -30,6 +31,14 @@ pub fn get_hash_size_mb() -> usize {
max(1, HASH_SIZE_MB.load(Ordering::Relaxed))
}

pub fn set_multi_pv(pv: usize) {
MULTI_PV.store(pv, Ordering::Relaxed);
}

pub fn get_multi_pv() -> usize {
max(1, MULTI_PV.load(Ordering::Relaxed))
}

pub fn set_cpuct(c: f32) {
let mut cp = CPUCT.write().unwrap();
*cp = c;
Expand Down
11 changes: 1 addition & 10 deletions src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,8 @@ impl Search {
}
}

pub fn principal_variation(&self, num_moves: usize) -> Vec<chess::Move> {
self.search_tree
.principal_variation(num_moves)
.into_iter()
.map(MoveEdge::get_move)
.copied()
.collect()
}

pub fn best_move(&self) -> chess::Move {
*self.principal_variation(1).get(0).unwrap()
self.search_tree.best_move()
}
}

Expand Down
126 changes: 74 additions & 52 deletions src/search_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::arena::Error as ArenaError;
use crate::chess;
use crate::evaluation::{self, Flag};
use crate::options::{
get_cpuct, get_cpuct_root, get_cvisits_selection, get_policy_temperature,
get_cpuct, get_cpuct_root, get_cvisits_selection, get_multi_pv, get_policy_temperature,
get_policy_temperature_root,
};
use crate::search::{eval_in_cp, ThreadData};
Expand Down Expand Up @@ -93,6 +93,24 @@ impl PositionNode {
h.child.store(null_mut(), Ordering::SeqCst);
}
}

pub fn select_child_by_rewards(&self) -> &MoveEdge {
let children = self.hots();

let mut best = &children[0];
let mut best_reward = best.average_reward().unwrap_or(-SCALE);

for child in children.iter().skip(1) {
let reward = child.average_reward().unwrap_or(-SCALE);

if reward > best_reward {
best = child;
best_reward = reward;
}
}

best
}
}

impl MoveEdge {
Expand Down Expand Up @@ -431,21 +449,9 @@ impl SearchTree {
&self.root_node
}

pub fn principal_variation(&self, num_moves: usize) -> Vec<&MoveEdge> {
let mut result = Vec::new();
let mut crnt = &self.root_node;
while !crnt.hots().is_empty() && result.len() < num_moves {
let choice = select_child_after_search(crnt.hots());
result.push(choice);
let child = choice.child.load(Ordering::SeqCst).cast_const();
if child.is_null() {
break;
}
unsafe {
crnt = &*child;
}
}
result
pub fn best_move(&self) -> chess::Move {
let best = sort_moves(&self.root_node.hots())[0];
*best.get_move()
}

pub fn print_info(&self, time_management: &TimeManagement) {
Expand All @@ -454,70 +460,86 @@ impl SearchTree {
let nodes = self.num_nodes();
let depth = nodes / self.playouts();
let sel_depth = self.max_depth();
let pv = self.principal_variation(depth.max(1));
let pv_string: String = pv.into_iter().fold(String::new(), |mut out, x| {
write!(out, " {}", x.get_move().to_uci()).unwrap();
out
});

let nps = if search_time_ms == 0 {
nodes
} else {
nodes * 1000 / search_time_ms as usize
};

let info_str = format!(
"info depth {} seldepth {} nodes {} nps {} tbhits {} score {} time {} pv{}",
depth.max(1),
sel_depth.max(1),
nodes,
nps,
self.tb_hits(),
self.eval_in_cp(),
search_time_ms,
pv_string,
);
println!("{info_str}");
}
let moves = sort_moves(&self.root_node.hots());

for (idx, edge) in moves.iter().enumerate().take(get_multi_pv()) {
let pv = match edge.child() {
Some(child) => principal_variation(child, depth.max(1) - 1),
None => vec![],
};

pub fn eval(&self) -> f32 {
self.principal_variation(1)
.get(0)
.map_or(0., |x| x.average_reward().unwrap_or(-SCALE) / SCALE)
let pv_string: String = pv.into_iter().fold(edge.get_move().to_uci(), |mut out, x| {
write!(out, " {}", x.get_move().to_uci()).unwrap();
out
});

let eval = eval_in_cp(edge.average_reward().unwrap_or(-SCALE) / SCALE);

let info_str = format!(
"info depth {} seldepth {} nodes {} nps {} tbhits {} score {} time {} multipv {} pv {}",
depth.max(1),
sel_depth.max(1),
nodes,
nps,
self.tb_hits(),
eval,
search_time_ms,
idx + 1,
pv_string,
);
println!("{info_str}");
}
}
}

fn eval_in_cp(&self) -> String {
eval_in_cp(self.eval())
fn principal_variation(from: &PositionNode, num_moves: usize) -> Vec<&MoveEdge> {
let mut result = Vec::with_capacity(num_moves);
let mut crnt = from;

while !crnt.hots().is_empty() && result.len() < num_moves {
let choice = crnt.select_child_by_rewards();
result.push(choice);

match choice.child() {
Some(child) => crnt = child,
None => break,
}
}

result
}

fn select_child_after_search(children: &[MoveEdge]) -> &MoveEdge {
fn sort_moves(children: &[MoveEdge]) -> Vec<&MoveEdge> {
let k = get_cvisits_selection();

let reward = |child: &MoveEdge| {
let visits = child.visits();

if visits == 0 {
return -SCALE;
return -(2. * SCALE) + f32::from(child.policy());
}

let sum_rewards = child.sum_rewards();

sum_rewards as f32 / visits as f32 - (k * 2. * SCALE) / (visits as f32).sqrt()
};

let mut best = &children[0];
let mut best_reward = reward(best);
let mut result = Vec::with_capacity(children.len());

for child in children.iter().skip(1) {
let reward = reward(child);
if reward > best_reward {
best = child;
best_reward = reward;
}
for child in children {
result.push((child, reward(child)));
}

best
result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
result.reverse();

result.into_iter().map(|x| x.0).collect()
}

pub fn print_size_list() {
Expand Down
4 changes: 3 additions & 1 deletion src/uci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io::stdin;
use std::str::{FromStr, SplitWhitespace};

use crate::options::{
set_chess960, set_cpuct, set_cpuct_root, set_cvisits_selection, set_hash_size_mb,
set_chess960, set_cpuct, set_cpuct_root, set_cvisits_selection, set_hash_size_mb, set_multi_pv,
set_num_threads, set_policy_temperature, set_policy_temperature_root,
};
use crate::search::Search;
Expand Down Expand Up @@ -74,6 +74,7 @@ pub fn uci() {
println!("id author {ENGINE_AUTHOR}");
println!("option name Hash type spin min 8 max 65536 default 16");
println!("option name Threads type spin min 1 max 255 default 1");
println!("option name MultiPV type spin min 1 max 255 default 1");
println!("option name SyzygyPath type string default <empty>");
println!("option name CPuct type string default 1.06");
println!("option name CPuctRoot type string default 3.17");
Expand Down Expand Up @@ -141,6 +142,7 @@ impl UciOption {
self.set_option(set_hash_size_mb);
search.reset_table();
}
"multipv" => self.set_option(set_multi_pv),
"cpuct" => self.set_option(set_cpuct),
"cpuctroot" => self.set_option(set_cpuct_root),
"cvisitsselection" => self.set_option(set_cvisits_selection),
Expand Down

0 comments on commit 2525b4a

Please sign in to comment.