From 018545a0de35a427cffc8f632f00083d7b5cbee0 Mon Sep 17 00:00:00 2001
From: ianagbip1oti <ianagbip1oti@gmail.com>
Date: Mon, 18 Dec 2023 09:35:56 +0000
Subject: [PATCH] =?UTF-8?q?=F0=9F=93=A6=20NEW:=20Use=2016=20bit=20move=20s?=
 =?UTF-8?q?truct?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore         |   1 +
 src/chess.rs       | 180 +++++++++++++++++++++++++++++++++++++++++++++
 src/evaluation.rs  |   5 +-
 src/main.rs        |  12 +--
 src/search.rs      |  21 +++---
 src/search_tree.rs |  15 ++--
 src/state.rs       | 128 ++++++++++++++++++--------------
 src/training.rs    |   6 +-
 8 files changed, 284 insertions(+), 84 deletions(-)
 create mode 100644 src/chess.rs

diff --git a/.gitignore b/.gitignore
index 971c055..5990bf8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,6 +28,7 @@ train/logs
 train/.env
 train/__pycache__
 train/pgn/*
+train/pgn.all
 train_data*
 train/*_weights
 train/*_bias
diff --git a/src/chess.rs b/src/chess.rs
new file mode 100644
index 0000000..194194c
--- /dev/null
+++ b/src/chess.rs
@@ -0,0 +1,180 @@
+use crate::options::is_chess960;
+use arrayvec::ArrayVec;
+
+#[derive(Debug, Copy, Clone)]
+pub struct Move(u16);
+
+pub type MoveList = ArrayVec<Move, 256>;
+
+impl Move {
+    const SQ_MASK: u16 = 0b11_1111;
+    const TO_SHIFT: u16 = 6;
+    const PROMO_MASK: u16 = 0b11;
+    const PROMO_SHIFT: u16 = 12;
+
+    const PROMO_FLAG: u16 = 0b1100_0000_0000_0000;
+    const ENPASSANT_FLAG: u16 = 0b0100_0000_0000_0000;
+    const CASTLE_FLAG: u16 = 0b1000_0000_0000_0000;
+
+    pub fn new(from: shakmaty::Square, to: shakmaty::Square) -> Self {
+        Self(from as u16 | ((to as u16) << Self::TO_SHIFT))
+    }
+
+    pub fn new_promotion(
+        from: shakmaty::Square,
+        to: shakmaty::Square,
+        promotion: shakmaty::Role,
+    ) -> Self {
+        Self(
+            Self::new(from, to).0
+                | Self::PROMO_FLAG
+                | ((role_to_promotion_idx(promotion)) << Self::PROMO_SHIFT),
+        )
+    }
+
+    pub fn new_enpassant(from: shakmaty::Square, to: shakmaty::Square) -> Self {
+        Self(Self::new(from, to).0 | Self::ENPASSANT_FLAG)
+    }
+
+    pub fn new_castle(king: shakmaty::Square, rook: shakmaty::Square) -> Self {
+        Self(Self::new(king, rook).0 | Self::CASTLE_FLAG)
+    }
+
+    pub fn from(self) -> shakmaty::Square {
+        unsafe { shakmaty::Square::new_unchecked(u32::from(self.0 & Self::SQ_MASK)) }
+    }
+
+    pub fn to(self) -> shakmaty::Square {
+        unsafe {
+            shakmaty::Square::new_unchecked(u32::from((self.0 >> Self::TO_SHIFT) & Self::SQ_MASK))
+        }
+    }
+
+    pub fn is_normal(self) -> bool {
+        self.0 & Self::ENPASSANT_FLAG == 0 && self.0 & Self::CASTLE_FLAG == 0
+    }
+
+    pub fn is_enpassant(self) -> bool {
+        self.0 & Self::ENPASSANT_FLAG != 0 && self.0 & Self::CASTLE_FLAG == 0
+    }
+
+    pub fn is_castle(self) -> bool {
+        self.0 & Self::CASTLE_FLAG != 0 && self.0 & Self::ENPASSANT_FLAG == 0
+    }
+
+    fn is_promotion(self) -> bool {
+        self.0 & Self::PROMO_FLAG == Self::PROMO_FLAG
+    }
+
+    pub fn promotion(self) -> Option<shakmaty::Role> {
+        if self.is_promotion() {
+            Some(promotion_idx_to_role(
+                (self.0 >> Self::PROMO_SHIFT) & Self::PROMO_MASK,
+            ))
+        } else {
+            None
+        }
+    }
+
+    pub fn to_uci(self) -> String {
+        let from = self.from();
+        let to = if self.is_castle() && !is_chess960() {
+            match self.to() {
+                shakmaty::Square::H1 => shakmaty::Square::G1,
+                shakmaty::Square::A1 => shakmaty::Square::C1,
+                shakmaty::Square::H8 => shakmaty::Square::G8,
+                shakmaty::Square::A8 => shakmaty::Square::C8,
+                _ => panic!("Invalid castle move: {self:?}"),
+            }
+        } else {
+            self.to()
+        };
+
+        let promotion = self
+            .promotion()
+            .map_or(String::new(), role_to_promotion_char);
+
+        format!("{from}{to}{promotion}")
+    }
+
+    pub fn to_shakmaty(self, board: &shakmaty::Board) -> shakmaty::Move {
+        let from = self.from();
+        let to = self.to();
+
+        if self.is_enpassant() {
+            return shakmaty::Move::EnPassant { from, to };
+        }
+
+        if self.is_castle() {
+            return shakmaty::Move::Castle {
+                king: from,
+                rook: to,
+            };
+        }
+
+        let promotion = self.promotion();
+        let role = board.role_at(from).unwrap();
+        let capture = board.role_at(to);
+
+        shakmaty::Move::Normal {
+            from,
+            to,
+            promotion,
+            role,
+            capture,
+        }
+    }
+}
+
+impl From<shakmaty::Move> for Move {
+    fn from(m: shakmaty::Move) -> Self {
+        match m {
+            shakmaty::Move::Normal {
+                from,
+                to,
+                promotion: None,
+                ..
+            } => Self::new(from, to),
+            shakmaty::Move::Normal {
+                from,
+                to,
+                promotion: Some(promo),
+                ..
+            } => Self::new_promotion(from, to, promo),
+            shakmaty::Move::EnPassant { from, to } => Self::new_enpassant(from, to),
+            shakmaty::Move::Castle { king, rook } => Self::new_castle(king, rook),
+            _ => panic!("Invalid move: {m:?}"),
+        }
+    }
+}
+
+fn role_to_promotion_char(role: shakmaty::Role) -> String {
+    match role {
+        shakmaty::Role::Queen => "q",
+        shakmaty::Role::Rook => "r",
+        shakmaty::Role::Bishop => "b",
+        shakmaty::Role::Knight => "n",
+        _ => panic!("Invalid promotion role: {role:?}"),
+    }
+    .to_string()
+}
+
+fn role_to_promotion_idx(role: shakmaty::Role) -> u16 {
+    match role {
+        shakmaty::Role::Queen => 0,
+        shakmaty::Role::Rook => 1,
+        shakmaty::Role::Bishop => 2,
+        shakmaty::Role::Knight => 3,
+        _ => panic!("Invalid promotion role: {role:?}"),
+    }
+}
+
+fn promotion_idx_to_role(idx: u16) -> shakmaty::Role {
+    match idx {
+        0 => shakmaty::Role::Queen,
+        1 => shakmaty::Role::Rook,
+        2 => shakmaty::Role::Bishop,
+        3 => shakmaty::Role::Knight,
+        _ => panic!("Invalid promotion index: {idx}"),
+    }
+}
diff --git a/src/evaluation.rs b/src/evaluation.rs
index 2da4e6a..188fab3 100644
--- a/src/evaluation.rs
+++ b/src/evaluation.rs
@@ -1,5 +1,6 @@
-use shakmaty::{MoveList, Position};
+use shakmaty::Position;
 
+use crate::chess::MoveList;
 use crate::math;
 use crate::search::SCALE;
 use crate::state::{self, State};
@@ -164,7 +165,7 @@ fn run_policy_net(state: &State, moves: &MoveList, t: f32) -> Vec<f32> {
     let mut acc = Vec::with_capacity(moves.len());
 
     for m in moves {
-        let move_idx = state.move_to_index(m);
+        let move_idx = state.move_to_index(*m);
         move_idxs.push(move_idx);
         acc.push((
             POLICY_NET.add_bias.vals[move_idx],
diff --git a/src/main.rs b/src/main.rs
index c6be41d..398602d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -15,18 +15,18 @@ extern crate rand;
 extern crate shakmaty;
 
 mod arena;
+mod args;
+mod chess;
+mod evaluation;
 mod math;
 mod options;
+mod search;
 mod search_tree;
+mod state;
 mod tablebase;
+mod training;
 mod transposition_table;
 mod tree_policy;
-
-mod args;
-mod evaluation;
-mod search;
-mod state;
-mod training;
 mod uci;
 
 fn main() {
diff --git a/src/search.rs b/src/search.rs
index 0fd5850..b9b01fa 100644
--- a/src/search.rs
+++ b/src/search.rs
@@ -1,7 +1,8 @@
-use shakmaty::{CastlingMode, Color, Move};
+use shakmaty::{CastlingMode, Color};
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::time::{Duration, Instant};
 
+use crate::chess;
 use crate::evaluation;
 use crate::options::{get_num_threads, get_policy_temperature, is_chess960};
 use crate::search_tree::{MoveEdge, SearchTree};
@@ -122,7 +123,7 @@ impl Search {
         let mvs = state.available_moves();
 
         if mvs.len() == 1 {
-            let uci_mv = to_uci(&mvs[0]);
+            let uci_mv = mvs[0].to_uci();
             println!("info depth 1 seldepth 1 nodes 1 nps 1 tbhits 0 time 1 pv {uci_mv}");
             println!("bestmove {uci_mv}");
             return;
@@ -235,7 +236,7 @@ impl Search {
                 run_search_thread(&time_management);
                 self.search_tree.print_info(&time_management);
                 stop_signal.store(true, Ordering::Relaxed);
-                println!("bestmove {}", to_uci(&self.best_move().unwrap()));
+                println!("bestmove {}", self.best_move().to_uci());
             });
 
             for _ in 0..(num_threads - 1) {
@@ -280,8 +281,8 @@ impl Search {
         moves.sort_by_key(|(h, e)| (h.average_reward().unwrap_or(*e) * SCALE) as i64);
         for (mov, e) in moves {
             println!(
-                "info string {:7} M: {:>6} P: {:>6} V: {:7} E: {:>6} ({:>8})",
-                format!("{}", mov.get_move()),
+                "info string {:7} M: {:5} P: {:>6} V: {:7} E: {:>6} ({:>8})",
+                format!("{}", mov.get_move().to_uci()),
                 format!("{:3.2}", e * 100.),
                 format!("{:3.2}", f32::from(mov.policy()) / SCALE * 100.),
                 mov.visits(),
@@ -293,21 +294,21 @@ impl Search {
         }
     }
 
-    pub fn principal_variation(&self, num_moves: usize) -> Vec<Move> {
+    pub fn principal_variation(&self, num_moves: usize) -> Vec<chess::Move> {
         self.search_tree
             .principal_variation(num_moves)
             .into_iter()
             .map(MoveEdge::get_move)
-            .cloned()
+            .copied()
             .collect()
     }
 
-    pub fn best_move(&self) -> Option<shakmaty::Move> {
-        self.principal_variation(1).get(0).cloned()
+    pub fn best_move(&self) -> chess::Move {
+        *self.principal_variation(1).get(0).unwrap()
     }
 }
 
-pub fn to_uci(mov: &Move) -> String {
+pub fn to_uci(mov: &shakmaty::Move) -> String {
     mov.to_uci(CastlingMode::from_chess960(is_chess960()))
         .to_string()
 }
diff --git a/src/search_tree.rs b/src/search_tree.rs
index 2c5f633..c77b5c7 100644
--- a/src/search_tree.rs
+++ b/src/search_tree.rs
@@ -8,13 +8,14 @@ use std::sync::atomic::{
 };
 
 use crate::arena::Error as ArenaError;
+use crate::chess;
 use crate::evaluation::{self, Flag};
 use crate::options::{
     get_cpuct, get_cpuct_root, get_cvisits_selection, get_policy_temperature,
     get_policy_temperature_root,
 };
 use crate::search::{eval_in_cp, ThreadData};
-use crate::search::{to_uci, TimeManagement, SCALE};
+use crate::search::{TimeManagement, SCALE};
 use crate::state::State;
 use crate::transposition_table::{LRAllocator, LRTable, TranspositionTable};
 use crate::tree_policy;
@@ -46,7 +47,7 @@ pub struct MoveEdge {
     sum_evaluations: AtomicI64,
     visits: AtomicU32,
     policy: u16,
-    mov: shakmaty::Move,
+    mov: chess::Move,
     child: AtomicPtr<PositionNode>,
 }
 
@@ -95,7 +96,7 @@ impl PositionNode {
 }
 
 impl MoveEdge {
-    fn new(policy: u16, mov: shakmaty::Move) -> Self {
+    fn new(policy: u16, mov: chess::Move) -> Self {
         Self {
             policy,
             sum_evaluations: AtomicI64::default(),
@@ -105,7 +106,7 @@ impl MoveEdge {
         }
     }
 
-    pub fn get_move(&self) -> &shakmaty::Move {
+    pub fn get_move(&self) -> &chess::Move {
         &self.mov
     }
 
@@ -193,7 +194,7 @@ where
 
     #[allow(clippy::cast_sign_loss)]
     for (i, x) in hots.iter_mut().enumerate() {
-        *x = MoveEdge::new((move_eval[i] * SCALE) as u16, moves[i].clone());
+        *x = MoveEdge::new((move_eval[i] * SCALE) as u16, moves[i]);
     }
 
     Ok(PositionNode::new(hots, state_flag))
@@ -303,7 +304,7 @@ impl SearchTree {
             let choice = tree_policy::choose_child(node.hots(), cpuct, fpu);
             choice.down();
             path.push(choice);
-            state.make_move(&choice.mov);
+            state.make_move(choice.mov);
 
             if choice.visits() == 1 {
                 evaln = evaluation::evaluate_state(&state);
@@ -453,7 +454,7 @@ impl SearchTree {
         let sel_depth = self.max_depth();
         let pv = self.principal_variation(depth.max(1));
         let pv_string: String = pv.into_iter().fold(String::new(), |mut out, x| {
-            write!(out, " {}", to_uci(x.get_move())).unwrap();
+            write!(out, " {}", x.get_move().to_uci()).unwrap();
             out
         });
 
diff --git a/src/state.rs b/src/state.rs
index d62754b..66d113d 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -3,11 +3,12 @@ use shakmaty::fen::Fen;
 use shakmaty::uci::Uci;
 use shakmaty::zobrist::{Zobrist64, ZobristHash, ZobristValue};
 use shakmaty::{
-    self, CastlingMode, CastlingSide, Chess, Color, EnPassantMode, File, Move, MoveList, Piece,
-    Position, Rank, Role, Square,
+    self, CastlingMode, CastlingSide, Chess, Color, EnPassantMode, File, Piece, Position, Rank,
+    Role, Square,
 };
 use std::convert::Into;
 
+use crate::chess;
 use crate::options::is_chess960;
 use crate::uci::Tokens;
 
@@ -21,7 +22,7 @@ pub const NUMBER_MOVE_IDX: usize = 384;
 pub struct Builder {
     initial_state: Chess,
     crnt_state: Chess,
-    moves: Vec<Move>,
+    moves: Vec<shakmaty::Move>,
 }
 
 impl Builder {
@@ -29,7 +30,7 @@ impl Builder {
         &self.crnt_state
     }
 
-    pub fn make_move(&mut self, mov: Move) {
+    pub fn make_move(&mut self, mov: shakmaty::Move) {
         self.crnt_state = self.crnt_state.clone().play(&mov).unwrap();
         self.moves.push(mov);
     }
@@ -71,7 +72,7 @@ impl Builder {
         Some(result)
     }
 
-    pub fn extract(&self) -> (State, Vec<Move>) {
+    pub fn extract(&self) -> (State, Vec<shakmaty::Move>) {
         let state = Self::from(self.initial_state.clone()).into();
         let moves = self.moves.clone();
         (state, moves)
@@ -86,7 +87,7 @@ pub struct State {
     // can go above this. Hence we add a little space
     prev_state_hashes: ArrayVec<u64, 128>,
 
-    prev_moves: [Option<Move>; 2],
+    prev_moves: [Option<chess::Move>; 2],
 
     repetitions: usize,
     hash: Zobrist64,
@@ -108,24 +109,35 @@ impl State {
         self.hash.0
     }
 
-    pub fn available_moves(&self) -> MoveList {
-        self.board.legal_moves()
+    pub fn available_moves(&self) -> chess::MoveList {
+        let mut moves = chess::MoveList::new();
+
+        for m in self.board.legal_moves() {
+            moves.push(m.into());
+        }
+
+        moves
     }
 
-    pub fn make_move(&mut self, mov: &Move) {
-        let is_pawn_move = mov.role() == Role::Pawn;
+    pub fn make_move(&mut self, mov: chess::Move) {
+        let b = self.board.board();
+        let role = b.role_at(mov.from()).unwrap();
 
-        self.prev_moves[0] = self.prev_moves[1].clone();
-        self.prev_moves[1] = Some(mov.clone());
+        let is_pawn_move = role == Role::Pawn;
+        let capture = b.role_at(mov.to());
 
-        if is_pawn_move || mov.is_capture() {
+        self.prev_moves[0] = self.prev_moves[1];
+        self.prev_moves[1] = Some(mov);
+
+        if is_pawn_move || capture.is_some() {
             self.prev_state_hashes.clear();
         }
         self.prev_state_hashes.push(self.hash());
 
         self.update_hash_pre();
-        self.board.play_unchecked(mov);
-        self.update_hash(!self.side_to_move(), mov);
+        self.board
+            .play_unchecked(&mov.to_shakmaty(self.board.board()));
+        self.update_hash(!self.side_to_move(), role, capture, mov);
 
         self.check_for_repetition();
     }
@@ -148,49 +160,52 @@ impl State {
         }
     }
 
-    fn update_hash(&mut self, color: Color, mv: &Move) {
-        match mv {
-            Move::Normal {
-                role,
-                from,
+    fn update_hash(
+        &mut self,
+        color: Color,
+        role: shakmaty::Role,
+        capture: Option<shakmaty::Role>,
+        mv: chess::Move,
+    ) {
+        if !mv.is_normal() {
+            self.hash = self.board.zobrist_hash(EnPassantMode::Always);
+            return;
+        }
+
+        let from = mv.from();
+        let to = mv.to();
+
+        let pc = Piece { color, role };
+        self.hash ^= Zobrist64::zobrist_for_piece(from, pc);
+        self.hash ^= Zobrist64::zobrist_for_piece(to, pc);
+
+        if let Some(captured) = capture {
+            self.hash ^= Zobrist64::zobrist_for_piece(
                 to,
-                capture,
-                promotion: None,
-            } => {
-                let pc = Piece { color, role: *role };
-                self.hash ^= Zobrist64::zobrist_for_piece(*from, pc);
-                self.hash ^= Zobrist64::zobrist_for_piece(*to, pc);
-
-                if let Some(captured) = capture {
-                    self.hash ^= Zobrist64::zobrist_for_piece(
-                        *to,
-                        Piece {
-                            color: !color,
-                            role: *captured,
-                        },
-                    );
-                }
+                Piece {
+                    color: !color,
+                    role: captured,
+                },
+            );
+        }
 
-                if let Some(ep_sq) = self.board.ep_square(EnPassantMode::Always) {
-                    self.hash ^= Zobrist64::zobrist_for_en_passant_file(ep_sq.file());
-                }
+        if let Some(ep_sq) = self.board.ep_square(EnPassantMode::Always) {
+            self.hash ^= Zobrist64::zobrist_for_en_passant_file(ep_sq.file());
+        }
 
-                let castles = self.board.castles();
+        let castles = self.board.castles();
 
-                if !castles.is_empty() {
-                    for color in Color::ALL {
-                        for side in CastlingSide::ALL {
-                            if castles.has(color, side) {
-                                self.hash ^= Zobrist64::zobrist_for_castling_right(color, side);
-                            }
-                        }
+        if !castles.is_empty() {
+            for color in Color::ALL {
+                for side in CastlingSide::ALL {
+                    if castles.has(color, side) {
+                        self.hash ^= Zobrist64::zobrist_for_castling_right(color, side);
                     }
                 }
-
-                self.hash ^= Zobrist64::zobrist_for_white_turn();
             }
-            _ => self.hash = self.board.zobrist_hash(EnPassantMode::Always),
-        };
+        }
+
+        self.hash ^= Zobrist64::zobrist_for_white_turn();
     }
 
     fn check_for_repetition(&mut self) {
@@ -272,12 +287,12 @@ impl State {
         // We use the king threats and defenses squares for previous moves
         if let Some(m) = &self.prev_moves[0] {
             f(OFFSET_DEFENDS + feature_idx(m.to(), Role::King, stm));
-            f(OFFSET_DEFENDS + feature_idx(m.from().unwrap(), Role::King, !stm));
+            f(OFFSET_DEFENDS + feature_idx(m.from(), Role::King, !stm));
         }
 
         if let Some(m) = &self.prev_moves[1] {
             f(OFFSET_THREATS + feature_idx(m.to(), Role::King, stm));
-            f(OFFSET_THREATS + feature_idx(m.from().unwrap(), Role::King, !stm));
+            f(OFFSET_THREATS + feature_idx(m.from(), Role::King, !stm));
         }
     }
 
@@ -302,7 +317,8 @@ impl State {
         self.features_map(f);
     }
 
-    pub fn move_to_index(&self, mv: &Move) -> usize {
+    pub fn move_to_index(&self, mv: chess::Move) -> usize {
+        let role = self.board.board().role_at(mv.from()).unwrap();
         let to_sq = mv.to();
 
         let (flip_vertical, flip_horizontal) = self.feature_flip();
@@ -314,7 +330,7 @@ impl State {
             (false, false) => sq,
         };
 
-        let role_idx = mv.role() as usize - 1;
+        let role_idx = role as usize - 1;
 
         let flip_to = flip_square(to_sq);
 
@@ -365,7 +381,7 @@ impl From<Builder> for State {
         };
 
         for mov in sb.moves {
-            state.make_move(&mov);
+            state.make_move(mov.into());
         }
 
         state
diff --git a/src/training.rs b/src/training.rs
index dbc4bde..4ede291 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -172,10 +172,10 @@ impl Visitor for ValueDataGenerator {
                 });
 
                 for m in moves.as_slice() {
-                    move_features[state.move_to_index(m)] = 2;
+                    move_features[state.move_to_index(*m)] = 2;
                 }
 
-                move_features[state.move_to_index(&made)] = 1;
+                move_features[state.move_to_index(made.clone().into())] = 1;
 
                 let mut f_vec =
                     Vec::with_capacity(1 + state::NUMBER_MOVE_IDX + state::NUMBER_FEATURES);
@@ -184,7 +184,7 @@ impl Visitor for ValueDataGenerator {
 
                 write_libsvm(&f_vec, &mut self.out_file, wdl);
             }
-            state.make_move(&made);
+            state.make_move(made.into());
         }
     }