diff --git a/.gitignore b/.gitignore index 971c055..5990bf8 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,7 @@ train/logs train/.env train/__pycache__ train/pgn/* +train/pgn.all train_data* train/*_weights train/*_bias diff --git a/src/chess.rs b/src/chess.rs new file mode 100644 index 0000000..194194c --- /dev/null +++ b/src/chess.rs @@ -0,0 +1,180 @@ +use crate::options::is_chess960; +use arrayvec::ArrayVec; + +#[derive(Debug, Copy, Clone)] +pub struct Move(u16); + +pub type MoveList = ArrayVec; + +impl Move { + const SQ_MASK: u16 = 0b11_1111; + const TO_SHIFT: u16 = 6; + const PROMO_MASK: u16 = 0b11; + const PROMO_SHIFT: u16 = 12; + + const PROMO_FLAG: u16 = 0b1100_0000_0000_0000; + const ENPASSANT_FLAG: u16 = 0b0100_0000_0000_0000; + const CASTLE_FLAG: u16 = 0b1000_0000_0000_0000; + + pub fn new(from: shakmaty::Square, to: shakmaty::Square) -> Self { + Self(from as u16 | ((to as u16) << Self::TO_SHIFT)) + } + + pub fn new_promotion( + from: shakmaty::Square, + to: shakmaty::Square, + promotion: shakmaty::Role, + ) -> Self { + Self( + Self::new(from, to).0 + | Self::PROMO_FLAG + | ((role_to_promotion_idx(promotion)) << Self::PROMO_SHIFT), + ) + } + + pub fn new_enpassant(from: shakmaty::Square, to: shakmaty::Square) -> Self { + Self(Self::new(from, to).0 | Self::ENPASSANT_FLAG) + } + + pub fn new_castle(king: shakmaty::Square, rook: shakmaty::Square) -> Self { + Self(Self::new(king, rook).0 | Self::CASTLE_FLAG) + } + + pub fn from(self) -> shakmaty::Square { + unsafe { shakmaty::Square::new_unchecked(u32::from(self.0 & Self::SQ_MASK)) } + } + + pub fn to(self) -> shakmaty::Square { + unsafe { + shakmaty::Square::new_unchecked(u32::from((self.0 >> Self::TO_SHIFT) & Self::SQ_MASK)) + } + } + + pub fn is_normal(self) -> bool { + self.0 & Self::ENPASSANT_FLAG == 0 && self.0 & Self::CASTLE_FLAG == 0 + } + + pub fn is_enpassant(self) -> bool { + self.0 & Self::ENPASSANT_FLAG != 0 && self.0 & Self::CASTLE_FLAG == 0 + } + + pub fn is_castle(self) -> bool { + self.0 & Self::CASTLE_FLAG != 0 && self.0 & Self::ENPASSANT_FLAG == 0 + } + + fn is_promotion(self) -> bool { + self.0 & Self::PROMO_FLAG == Self::PROMO_FLAG + } + + pub fn promotion(self) -> Option { + if self.is_promotion() { + Some(promotion_idx_to_role( + (self.0 >> Self::PROMO_SHIFT) & Self::PROMO_MASK, + )) + } else { + None + } + } + + pub fn to_uci(self) -> String { + let from = self.from(); + let to = if self.is_castle() && !is_chess960() { + match self.to() { + shakmaty::Square::H1 => shakmaty::Square::G1, + shakmaty::Square::A1 => shakmaty::Square::C1, + shakmaty::Square::H8 => shakmaty::Square::G8, + shakmaty::Square::A8 => shakmaty::Square::C8, + _ => panic!("Invalid castle move: {self:?}"), + } + } else { + self.to() + }; + + let promotion = self + .promotion() + .map_or(String::new(), role_to_promotion_char); + + format!("{from}{to}{promotion}") + } + + pub fn to_shakmaty(self, board: &shakmaty::Board) -> shakmaty::Move { + let from = self.from(); + let to = self.to(); + + if self.is_enpassant() { + return shakmaty::Move::EnPassant { from, to }; + } + + if self.is_castle() { + return shakmaty::Move::Castle { + king: from, + rook: to, + }; + } + + let promotion = self.promotion(); + let role = board.role_at(from).unwrap(); + let capture = board.role_at(to); + + shakmaty::Move::Normal { + from, + to, + promotion, + role, + capture, + } + } +} + +impl From for Move { + fn from(m: shakmaty::Move) -> Self { + match m { + shakmaty::Move::Normal { + from, + to, + promotion: None, + .. + } => Self::new(from, to), + shakmaty::Move::Normal { + from, + to, + promotion: Some(promo), + .. + } => Self::new_promotion(from, to, promo), + shakmaty::Move::EnPassant { from, to } => Self::new_enpassant(from, to), + shakmaty::Move::Castle { king, rook } => Self::new_castle(king, rook), + _ => panic!("Invalid move: {m:?}"), + } + } +} + +fn role_to_promotion_char(role: shakmaty::Role) -> String { + match role { + shakmaty::Role::Queen => "q", + shakmaty::Role::Rook => "r", + shakmaty::Role::Bishop => "b", + shakmaty::Role::Knight => "n", + _ => panic!("Invalid promotion role: {role:?}"), + } + .to_string() +} + +fn role_to_promotion_idx(role: shakmaty::Role) -> u16 { + match role { + shakmaty::Role::Queen => 0, + shakmaty::Role::Rook => 1, + shakmaty::Role::Bishop => 2, + shakmaty::Role::Knight => 3, + _ => panic!("Invalid promotion role: {role:?}"), + } +} + +fn promotion_idx_to_role(idx: u16) -> shakmaty::Role { + match idx { + 0 => shakmaty::Role::Queen, + 1 => shakmaty::Role::Rook, + 2 => shakmaty::Role::Bishop, + 3 => shakmaty::Role::Knight, + _ => panic!("Invalid promotion index: {idx}"), + } +} diff --git a/src/evaluation.rs b/src/evaluation.rs index 2da4e6a..188fab3 100644 --- a/src/evaluation.rs +++ b/src/evaluation.rs @@ -1,5 +1,6 @@ -use shakmaty::{MoveList, Position}; +use shakmaty::Position; +use crate::chess::MoveList; use crate::math; use crate::search::SCALE; use crate::state::{self, State}; @@ -164,7 +165,7 @@ fn run_policy_net(state: &State, moves: &MoveList, t: f32) -> Vec { let mut acc = Vec::with_capacity(moves.len()); for m in moves { - let move_idx = state.move_to_index(m); + let move_idx = state.move_to_index(*m); move_idxs.push(move_idx); acc.push(( POLICY_NET.add_bias.vals[move_idx], diff --git a/src/main.rs b/src/main.rs index c6be41d..398602d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,18 +15,18 @@ extern crate rand; extern crate shakmaty; mod arena; +mod args; +mod chess; +mod evaluation; mod math; mod options; +mod search; mod search_tree; +mod state; mod tablebase; +mod training; mod transposition_table; mod tree_policy; - -mod args; -mod evaluation; -mod search; -mod state; -mod training; mod uci; fn main() { diff --git a/src/search.rs b/src/search.rs index 0fd5850..b9b01fa 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,7 +1,8 @@ -use shakmaty::{CastlingMode, Color, Move}; +use shakmaty::{CastlingMode, Color}; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::{Duration, Instant}; +use crate::chess; use crate::evaluation; use crate::options::{get_num_threads, get_policy_temperature, is_chess960}; use crate::search_tree::{MoveEdge, SearchTree}; @@ -122,7 +123,7 @@ impl Search { let mvs = state.available_moves(); if mvs.len() == 1 { - let uci_mv = to_uci(&mvs[0]); + let uci_mv = mvs[0].to_uci(); println!("info depth 1 seldepth 1 nodes 1 nps 1 tbhits 0 time 1 pv {uci_mv}"); println!("bestmove {uci_mv}"); return; @@ -235,7 +236,7 @@ impl Search { run_search_thread(&time_management); self.search_tree.print_info(&time_management); stop_signal.store(true, Ordering::Relaxed); - println!("bestmove {}", to_uci(&self.best_move().unwrap())); + println!("bestmove {}", self.best_move().to_uci()); }); for _ in 0..(num_threads - 1) { @@ -280,8 +281,8 @@ impl Search { moves.sort_by_key(|(h, e)| (h.average_reward().unwrap_or(*e) * SCALE) as i64); for (mov, e) in moves { println!( - "info string {:7} M: {:>6} P: {:>6} V: {:7} E: {:>6} ({:>8})", - format!("{}", mov.get_move()), + "info string {:7} M: {:5} P: {:>6} V: {:7} E: {:>6} ({:>8})", + format!("{}", mov.get_move().to_uci()), format!("{:3.2}", e * 100.), format!("{:3.2}", f32::from(mov.policy()) / SCALE * 100.), mov.visits(), @@ -293,21 +294,21 @@ impl Search { } } - pub fn principal_variation(&self, num_moves: usize) -> Vec { + pub fn principal_variation(&self, num_moves: usize) -> Vec { self.search_tree .principal_variation(num_moves) .into_iter() .map(MoveEdge::get_move) - .cloned() + .copied() .collect() } - pub fn best_move(&self) -> Option { - self.principal_variation(1).get(0).cloned() + pub fn best_move(&self) -> chess::Move { + *self.principal_variation(1).get(0).unwrap() } } -pub fn to_uci(mov: &Move) -> String { +pub fn to_uci(mov: &shakmaty::Move) -> String { mov.to_uci(CastlingMode::from_chess960(is_chess960())) .to_string() } diff --git a/src/search_tree.rs b/src/search_tree.rs index 2c5f633..c77b5c7 100644 --- a/src/search_tree.rs +++ b/src/search_tree.rs @@ -8,13 +8,14 @@ use std::sync::atomic::{ }; use crate::arena::Error as ArenaError; +use crate::chess; use crate::evaluation::{self, Flag}; use crate::options::{ get_cpuct, get_cpuct_root, get_cvisits_selection, get_policy_temperature, get_policy_temperature_root, }; use crate::search::{eval_in_cp, ThreadData}; -use crate::search::{to_uci, TimeManagement, SCALE}; +use crate::search::{TimeManagement, SCALE}; use crate::state::State; use crate::transposition_table::{LRAllocator, LRTable, TranspositionTable}; use crate::tree_policy; @@ -46,7 +47,7 @@ pub struct MoveEdge { sum_evaluations: AtomicI64, visits: AtomicU32, policy: u16, - mov: shakmaty::Move, + mov: chess::Move, child: AtomicPtr, } @@ -95,7 +96,7 @@ impl PositionNode { } impl MoveEdge { - fn new(policy: u16, mov: shakmaty::Move) -> Self { + fn new(policy: u16, mov: chess::Move) -> Self { Self { policy, sum_evaluations: AtomicI64::default(), @@ -105,7 +106,7 @@ impl MoveEdge { } } - pub fn get_move(&self) -> &shakmaty::Move { + pub fn get_move(&self) -> &chess::Move { &self.mov } @@ -193,7 +194,7 @@ where #[allow(clippy::cast_sign_loss)] for (i, x) in hots.iter_mut().enumerate() { - *x = MoveEdge::new((move_eval[i] * SCALE) as u16, moves[i].clone()); + *x = MoveEdge::new((move_eval[i] * SCALE) as u16, moves[i]); } Ok(PositionNode::new(hots, state_flag)) @@ -303,7 +304,7 @@ impl SearchTree { let choice = tree_policy::choose_child(node.hots(), cpuct, fpu); choice.down(); path.push(choice); - state.make_move(&choice.mov); + state.make_move(choice.mov); if choice.visits() == 1 { evaln = evaluation::evaluate_state(&state); @@ -453,7 +454,7 @@ impl SearchTree { let sel_depth = self.max_depth(); let pv = self.principal_variation(depth.max(1)); let pv_string: String = pv.into_iter().fold(String::new(), |mut out, x| { - write!(out, " {}", to_uci(x.get_move())).unwrap(); + write!(out, " {}", x.get_move().to_uci()).unwrap(); out }); diff --git a/src/state.rs b/src/state.rs index d62754b..66d113d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -3,11 +3,12 @@ use shakmaty::fen::Fen; use shakmaty::uci::Uci; use shakmaty::zobrist::{Zobrist64, ZobristHash, ZobristValue}; use shakmaty::{ - self, CastlingMode, CastlingSide, Chess, Color, EnPassantMode, File, Move, MoveList, Piece, - Position, Rank, Role, Square, + self, CastlingMode, CastlingSide, Chess, Color, EnPassantMode, File, Piece, Position, Rank, + Role, Square, }; use std::convert::Into; +use crate::chess; use crate::options::is_chess960; use crate::uci::Tokens; @@ -21,7 +22,7 @@ pub const NUMBER_MOVE_IDX: usize = 384; pub struct Builder { initial_state: Chess, crnt_state: Chess, - moves: Vec, + moves: Vec, } impl Builder { @@ -29,7 +30,7 @@ impl Builder { &self.crnt_state } - pub fn make_move(&mut self, mov: Move) { + pub fn make_move(&mut self, mov: shakmaty::Move) { self.crnt_state = self.crnt_state.clone().play(&mov).unwrap(); self.moves.push(mov); } @@ -71,7 +72,7 @@ impl Builder { Some(result) } - pub fn extract(&self) -> (State, Vec) { + pub fn extract(&self) -> (State, Vec) { let state = Self::from(self.initial_state.clone()).into(); let moves = self.moves.clone(); (state, moves) @@ -86,7 +87,7 @@ pub struct State { // can go above this. Hence we add a little space prev_state_hashes: ArrayVec, - prev_moves: [Option; 2], + prev_moves: [Option; 2], repetitions: usize, hash: Zobrist64, @@ -108,24 +109,35 @@ impl State { self.hash.0 } - pub fn available_moves(&self) -> MoveList { - self.board.legal_moves() + pub fn available_moves(&self) -> chess::MoveList { + let mut moves = chess::MoveList::new(); + + for m in self.board.legal_moves() { + moves.push(m.into()); + } + + moves } - pub fn make_move(&mut self, mov: &Move) { - let is_pawn_move = mov.role() == Role::Pawn; + pub fn make_move(&mut self, mov: chess::Move) { + let b = self.board.board(); + let role = b.role_at(mov.from()).unwrap(); - self.prev_moves[0] = self.prev_moves[1].clone(); - self.prev_moves[1] = Some(mov.clone()); + let is_pawn_move = role == Role::Pawn; + let capture = b.role_at(mov.to()); - if is_pawn_move || mov.is_capture() { + self.prev_moves[0] = self.prev_moves[1]; + self.prev_moves[1] = Some(mov); + + if is_pawn_move || capture.is_some() { self.prev_state_hashes.clear(); } self.prev_state_hashes.push(self.hash()); self.update_hash_pre(); - self.board.play_unchecked(mov); - self.update_hash(!self.side_to_move(), mov); + self.board + .play_unchecked(&mov.to_shakmaty(self.board.board())); + self.update_hash(!self.side_to_move(), role, capture, mov); self.check_for_repetition(); } @@ -148,49 +160,52 @@ impl State { } } - fn update_hash(&mut self, color: Color, mv: &Move) { - match mv { - Move::Normal { - role, - from, + fn update_hash( + &mut self, + color: Color, + role: shakmaty::Role, + capture: Option, + mv: chess::Move, + ) { + if !mv.is_normal() { + self.hash = self.board.zobrist_hash(EnPassantMode::Always); + return; + } + + let from = mv.from(); + let to = mv.to(); + + let pc = Piece { color, role }; + self.hash ^= Zobrist64::zobrist_for_piece(from, pc); + self.hash ^= Zobrist64::zobrist_for_piece(to, pc); + + if let Some(captured) = capture { + self.hash ^= Zobrist64::zobrist_for_piece( to, - capture, - promotion: None, - } => { - let pc = Piece { color, role: *role }; - self.hash ^= Zobrist64::zobrist_for_piece(*from, pc); - self.hash ^= Zobrist64::zobrist_for_piece(*to, pc); - - if let Some(captured) = capture { - self.hash ^= Zobrist64::zobrist_for_piece( - *to, - Piece { - color: !color, - role: *captured, - }, - ); - } + Piece { + color: !color, + role: captured, + }, + ); + } - if let Some(ep_sq) = self.board.ep_square(EnPassantMode::Always) { - self.hash ^= Zobrist64::zobrist_for_en_passant_file(ep_sq.file()); - } + if let Some(ep_sq) = self.board.ep_square(EnPassantMode::Always) { + self.hash ^= Zobrist64::zobrist_for_en_passant_file(ep_sq.file()); + } - let castles = self.board.castles(); + let castles = self.board.castles(); - if !castles.is_empty() { - for color in Color::ALL { - for side in CastlingSide::ALL { - if castles.has(color, side) { - self.hash ^= Zobrist64::zobrist_for_castling_right(color, side); - } - } + if !castles.is_empty() { + for color in Color::ALL { + for side in CastlingSide::ALL { + if castles.has(color, side) { + self.hash ^= Zobrist64::zobrist_for_castling_right(color, side); } } - - self.hash ^= Zobrist64::zobrist_for_white_turn(); } - _ => self.hash = self.board.zobrist_hash(EnPassantMode::Always), - }; + } + + self.hash ^= Zobrist64::zobrist_for_white_turn(); } fn check_for_repetition(&mut self) { @@ -272,12 +287,12 @@ impl State { // We use the king threats and defenses squares for previous moves if let Some(m) = &self.prev_moves[0] { f(OFFSET_DEFENDS + feature_idx(m.to(), Role::King, stm)); - f(OFFSET_DEFENDS + feature_idx(m.from().unwrap(), Role::King, !stm)); + f(OFFSET_DEFENDS + feature_idx(m.from(), Role::King, !stm)); } if let Some(m) = &self.prev_moves[1] { f(OFFSET_THREATS + feature_idx(m.to(), Role::King, stm)); - f(OFFSET_THREATS + feature_idx(m.from().unwrap(), Role::King, !stm)); + f(OFFSET_THREATS + feature_idx(m.from(), Role::King, !stm)); } } @@ -302,7 +317,8 @@ impl State { self.features_map(f); } - pub fn move_to_index(&self, mv: &Move) -> usize { + pub fn move_to_index(&self, mv: chess::Move) -> usize { + let role = self.board.board().role_at(mv.from()).unwrap(); let to_sq = mv.to(); let (flip_vertical, flip_horizontal) = self.feature_flip(); @@ -314,7 +330,7 @@ impl State { (false, false) => sq, }; - let role_idx = mv.role() as usize - 1; + let role_idx = role as usize - 1; let flip_to = flip_square(to_sq); @@ -365,7 +381,7 @@ impl From for State { }; for mov in sb.moves { - state.make_move(&mov); + state.make_move(mov.into()); } state diff --git a/src/training.rs b/src/training.rs index dbc4bde..4ede291 100644 --- a/src/training.rs +++ b/src/training.rs @@ -172,10 +172,10 @@ impl Visitor for ValueDataGenerator { }); for m in moves.as_slice() { - move_features[state.move_to_index(m)] = 2; + move_features[state.move_to_index(*m)] = 2; } - move_features[state.move_to_index(&made)] = 1; + move_features[state.move_to_index(made.clone().into())] = 1; let mut f_vec = Vec::with_capacity(1 + state::NUMBER_MOVE_IDX + state::NUMBER_FEATURES); @@ -184,7 +184,7 @@ impl Visitor for ValueDataGenerator { write_libsvm(&f_vec, &mut self.out_file, wdl); } - state.make_move(&made); + state.make_move(made.into()); } }