From 2525b4a93a8f75faa4c167c93be15cf15cd40c7f Mon Sep 17 00:00:00 2001 From: ianagbip1oti Date: Thu, 27 Jun 2024 06:10:10 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=A6=20NEW:=20Add=20MultiPV=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/options.rs | 9 ++++ src/search.rs | 11 +--- src/search_tree.rs | 126 ++++++++++++++++++++++++++------------------- src/uci.rs | 4 +- 4 files changed, 87 insertions(+), 63 deletions(-) diff --git a/src/options.rs b/src/options.rs index 6643e93..36343bf 100644 --- a/src/options.rs +++ b/src/options.rs @@ -5,6 +5,7 @@ use std::sync::RwLock; static NUM_THREADS: AtomicUsize = AtomicUsize::new(1); static HASH_SIZE_MB: AtomicUsize = AtomicUsize::new(16); +static MULTI_PV: AtomicUsize = AtomicUsize::new(1); static CPUCT: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(1.06)); static CPUCT_ROOT: Lazy<RwLock<f32>> = Lazy::new(|| RwLock::new(3.17)); @@ -30,6 +31,14 @@ pub fn get_hash_size_mb() -> usize { max(1, HASH_SIZE_MB.load(Ordering::Relaxed)) } +pub fn set_multi_pv(pv: usize) { + MULTI_PV.store(pv, Ordering::Relaxed); +} + +pub fn get_multi_pv() -> usize { + max(1, MULTI_PV.load(Ordering::Relaxed)) +} + pub fn set_cpuct(c: f32) { let mut cp = CPUCT.write().unwrap(); *cp = c; diff --git a/src/search.rs b/src/search.rs index d43873a..08a2432 100644 --- a/src/search.rs +++ b/src/search.rs @@ -294,17 +294,8 @@ impl Search { } } - pub fn principal_variation(&self, num_moves: usize) -> Vec<chess::Move> { - self.search_tree - .principal_variation(num_moves) - .into_iter() - .map(MoveEdge::get_move) - .copied() - .collect() - } - pub fn best_move(&self) -> chess::Move { - *self.principal_variation(1).get(0).unwrap() + self.search_tree.best_move() } } diff --git a/src/search_tree.rs b/src/search_tree.rs index f471e56..99eaa84 100644 --- a/src/search_tree.rs +++ b/src/search_tree.rs @@ -11,7 +11,7 @@ use crate::arena::Error as ArenaError; use crate::chess; use crate::evaluation::{self, Flag}; use crate::options::{ - 
get_cpuct, get_cpuct_root, get_cvisits_selection, get_policy_temperature, + get_cpuct, get_cpuct_root, get_cvisits_selection, get_multi_pv, get_policy_temperature, get_policy_temperature_root, }; use crate::search::{eval_in_cp, ThreadData}; @@ -93,6 +93,24 @@ impl PositionNode { h.child.store(null_mut(), Ordering::SeqCst); } } + + pub fn select_child_by_rewards(&self) -> &MoveEdge { + let children = self.hots(); + + let mut best = &children[0]; + let mut best_reward = best.average_reward().unwrap_or(-SCALE); + + for child in children.iter().skip(1) { + let reward = child.average_reward().unwrap_or(-SCALE); + + if reward > best_reward { + best = child; + best_reward = reward; + } + } + + best + } } impl MoveEdge { @@ -431,21 +449,9 @@ impl SearchTree { &self.root_node } - pub fn principal_variation(&self, num_moves: usize) -> Vec<&MoveEdge> { - let mut result = Vec::new(); - let mut crnt = &self.root_node; - while !crnt.hots().is_empty() && result.len() < num_moves { - let choice = select_child_after_search(crnt.hots()); - result.push(choice); - let child = choice.child.load(Ordering::SeqCst).cast_const(); - if child.is_null() { - break; - } - unsafe { - crnt = &*child; - } - } - result + pub fn best_move(&self) -> chess::Move { + let best = sort_moves(&self.root_node.hots())[0]; + *best.get_move() } pub fn print_info(&self, time_management: &TimeManagement) { @@ -454,51 +460,69 @@ impl SearchTree { let nodes = self.num_nodes(); let depth = nodes / self.playouts(); let sel_depth = self.max_depth(); - let pv = self.principal_variation(depth.max(1)); - let pv_string: String = pv.into_iter().fold(String::new(), |mut out, x| { - write!(out, " {}", x.get_move().to_uci()).unwrap(); - out - }); - let nps = if search_time_ms == 0 { nodes } else { nodes * 1000 / search_time_ms as usize }; - let info_str = format!( - "info depth {} seldepth {} nodes {} nps {} tbhits {} score {} time {} pv{}", - depth.max(1), - sel_depth.max(1), - nodes, - nps, - self.tb_hits(), - 
self.eval_in_cp(), - search_time_ms, - pv_string, - ); - println!("{info_str}"); - } + let moves = sort_moves(&self.root_node.hots()); + + for (idx, edge) in moves.iter().enumerate().take(get_multi_pv()) { + let pv = match edge.child() { + Some(child) => principal_variation(child, depth.max(1) - 1), + None => vec![], + }; - pub fn eval(&self) -> f32 { - self.principal_variation(1) - .get(0) - .map_or(0., |x| x.average_reward().unwrap_or(-SCALE) / SCALE) + let pv_string: String = pv.into_iter().fold(edge.get_move().to_uci(), |mut out, x| { + write!(out, " {}", x.get_move().to_uci()).unwrap(); + out + }); + + let eval = eval_in_cp(edge.average_reward().unwrap_or(-SCALE) / SCALE); + + let info_str = format!( + "info depth {} seldepth {} nodes {} nps {} tbhits {} score {} time {} multipv {} pv {}", + depth.max(1), + sel_depth.max(1), + nodes, + nps, + self.tb_hits(), + eval, + search_time_ms, + idx + 1, + pv_string, + ); + println!("{info_str}"); + } } +} - fn eval_in_cp(&self) -> String { - eval_in_cp(self.eval()) +fn principal_variation(from: &PositionNode, num_moves: usize) -> Vec<&MoveEdge> { + let mut result = Vec::with_capacity(num_moves); + let mut crnt = from; + + while !crnt.hots().is_empty() && result.len() < num_moves { + let choice = crnt.select_child_by_rewards(); + result.push(choice); + + match choice.child() { + Some(child) => crnt = child, + None => break, + } } + + result } -fn select_child_after_search(children: &[MoveEdge]) -> &MoveEdge { +fn sort_moves(children: &[MoveEdge]) -> Vec<&MoveEdge> { let k = get_cvisits_selection(); let reward = |child: &MoveEdge| { let visits = child.visits(); if visits == 0 { - return -SCALE; + return -(2. * SCALE) + f32::from(child.policy()); } let sum_rewards = child.sum_rewards(); @@ -506,18 +530,16 @@ fn select_child_after_search(children: &[MoveEdge]) -> &MoveEdge { sum_rewards as f32 / visits as f32 - (k * 2. 
* SCALE) / (visits as f32).sqrt() }; - let mut best = &children[0]; - let mut best_reward = reward(best); + let mut result = Vec::with_capacity(children.len()); - for child in children.iter().skip(1) { - let reward = reward(child); - if reward > best_reward { - best = child; - best_reward = reward; - } + for child in children { + result.push((child, reward(child))); } - best + result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + result.reverse(); + + result.into_iter().map(|x| x.0).collect() } pub fn print_size_list() { diff --git a/src/uci.rs b/src/uci.rs index acb598e..bb740ed 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -2,7 +2,7 @@ use std::io::stdin; use std::str::{FromStr, SplitWhitespace}; use crate::options::{ - set_chess960, set_cpuct, set_cpuct_root, set_cvisits_selection, set_hash_size_mb, + set_chess960, set_cpuct, set_cpuct_root, set_cvisits_selection, set_hash_size_mb, set_multi_pv, set_num_threads, set_policy_temperature, set_policy_temperature_root, }; use crate::search::Search; @@ -74,6 +74,7 @@ pub fn uci() { println!("id author {ENGINE_AUTHOR}"); println!("option name Hash type spin min 8 max 65536 default 16"); println!("option name Threads type spin min 1 max 255 default 1"); + println!("option name MultiPV type spin min 1 max 255 default 1"); println!("option name SyzygyPath type string default "); println!("option name CPuct type string default 1.06"); println!("option name CPuctRoot type string default 3.17"); @@ -141,6 +142,7 @@ impl UciOption { self.set_option(set_hash_size_mb); search.reset_table(); } + "multipv" => self.set_option(set_multi_pv), "cpuct" => self.set_option(set_cpuct), "cpuctroot" => self.set_option(set_cpuct_root), "cvisitsselection" => self.set_option(set_cvisits_selection),