Skip to content

Commit

Permalink
📦 NEW: Use 16 bit move struct
Browse files Browse the repository at this point in the history
  • Loading branch information
ianagbip1oti committed Dec 18, 2023
1 parent f477d20 commit 018545a
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 84 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ train/logs
train/.env
train/__pycache__
train/pgn/*
train/pgn.all
train_data*
train/*_weights
train/*_bias
Expand Down
180 changes: 180 additions & 0 deletions src/chess.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
use crate::options::is_chess960;
use arrayvec::ArrayVec;

#[derive(Debug, Copy, Clone)]
pub struct Move(u16);

pub type MoveList = ArrayVec<Move, 256>;

impl Move {
const SQ_MASK: u16 = 0b11_1111;
const TO_SHIFT: u16 = 6;
const PROMO_MASK: u16 = 0b11;
const PROMO_SHIFT: u16 = 12;

const PROMO_FLAG: u16 = 0b1100_0000_0000_0000;
const ENPASSANT_FLAG: u16 = 0b0100_0000_0000_0000;
const CASTLE_FLAG: u16 = 0b1000_0000_0000_0000;

pub fn new(from: shakmaty::Square, to: shakmaty::Square) -> Self {
Self(from as u16 | ((to as u16) << Self::TO_SHIFT))
}

pub fn new_promotion(
from: shakmaty::Square,
to: shakmaty::Square,
promotion: shakmaty::Role,
) -> Self {
Self(
Self::new(from, to).0
| Self::PROMO_FLAG
| ((role_to_promotion_idx(promotion)) << Self::PROMO_SHIFT),
)
}

pub fn new_enpassant(from: shakmaty::Square, to: shakmaty::Square) -> Self {
Self(Self::new(from, to).0 | Self::ENPASSANT_FLAG)
}

pub fn new_castle(king: shakmaty::Square, rook: shakmaty::Square) -> Self {
Self(Self::new(king, rook).0 | Self::CASTLE_FLAG)
}

pub fn from(self) -> shakmaty::Square {
unsafe { shakmaty::Square::new_unchecked(u32::from(self.0 & Self::SQ_MASK)) }
}

pub fn to(self) -> shakmaty::Square {
unsafe {
shakmaty::Square::new_unchecked(u32::from((self.0 >> Self::TO_SHIFT) & Self::SQ_MASK))
}
}

pub fn is_normal(self) -> bool {
self.0 & Self::ENPASSANT_FLAG == 0 && self.0 & Self::CASTLE_FLAG == 0
}

pub fn is_enpassant(self) -> bool {
self.0 & Self::ENPASSANT_FLAG != 0 && self.0 & Self::CASTLE_FLAG == 0
}

pub fn is_castle(self) -> bool {
self.0 & Self::CASTLE_FLAG != 0 && self.0 & Self::ENPASSANT_FLAG == 0
}

fn is_promotion(self) -> bool {
self.0 & Self::PROMO_FLAG == Self::PROMO_FLAG
}

pub fn promotion(self) -> Option<shakmaty::Role> {
if self.is_promotion() {
Some(promotion_idx_to_role(
(self.0 >> Self::PROMO_SHIFT) & Self::PROMO_MASK,
))
} else {
None
}
}

pub fn to_uci(self) -> String {
let from = self.from();
let to = if self.is_castle() && !is_chess960() {
match self.to() {
shakmaty::Square::H1 => shakmaty::Square::G1,
shakmaty::Square::A1 => shakmaty::Square::C1,
shakmaty::Square::H8 => shakmaty::Square::G8,
shakmaty::Square::A8 => shakmaty::Square::C8,
_ => panic!("Invalid castle move: {self:?}"),
}
} else {
self.to()
};

let promotion = self
.promotion()
.map_or(String::new(), role_to_promotion_char);

format!("{from}{to}{promotion}")
}

pub fn to_shakmaty(self, board: &shakmaty::Board) -> shakmaty::Move {
let from = self.from();
let to = self.to();

if self.is_enpassant() {
return shakmaty::Move::EnPassant { from, to };
}

if self.is_castle() {
return shakmaty::Move::Castle {
king: from,
rook: to,
};
}

let promotion = self.promotion();
let role = board.role_at(from).unwrap();
let capture = board.role_at(to);

shakmaty::Move::Normal {
from,
to,
promotion,
role,
capture,
}
}
}

impl From<shakmaty::Move> for Move {
fn from(m: shakmaty::Move) -> Self {
match m {
shakmaty::Move::Normal {
from,
to,
promotion: None,
..
} => Self::new(from, to),
shakmaty::Move::Normal {
from,
to,
promotion: Some(promo),
..
} => Self::new_promotion(from, to, promo),
shakmaty::Move::EnPassant { from, to } => Self::new_enpassant(from, to),
shakmaty::Move::Castle { king, rook } => Self::new_castle(king, rook),
_ => panic!("Invalid move: {m:?}"),
}
}
}

fn role_to_promotion_char(role: shakmaty::Role) -> String {
match role {
shakmaty::Role::Queen => "q",
shakmaty::Role::Rook => "r",
shakmaty::Role::Bishop => "b",
shakmaty::Role::Knight => "n",
_ => panic!("Invalid promotion role: {role:?}"),
}
.to_string()
}

fn role_to_promotion_idx(role: shakmaty::Role) -> u16 {
match role {
shakmaty::Role::Queen => 0,
shakmaty::Role::Rook => 1,
shakmaty::Role::Bishop => 2,
shakmaty::Role::Knight => 3,
_ => panic!("Invalid promotion role: {role:?}"),
}
}

fn promotion_idx_to_role(idx: u16) -> shakmaty::Role {
match idx {
0 => shakmaty::Role::Queen,
1 => shakmaty::Role::Rook,
2 => shakmaty::Role::Bishop,
3 => shakmaty::Role::Knight,
_ => panic!("Invalid promotion index: {idx}"),
}
}
5 changes: 3 additions & 2 deletions src/evaluation.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use shakmaty::{MoveList, Position};
use shakmaty::Position;

use crate::chess::MoveList;
use crate::math;
use crate::search::SCALE;
use crate::state::{self, State};
Expand Down Expand Up @@ -164,7 +165,7 @@ fn run_policy_net(state: &State, moves: &MoveList, t: f32) -> Vec<f32> {
let mut acc = Vec::with_capacity(moves.len());

for m in moves {
let move_idx = state.move_to_index(m);
let move_idx = state.move_to_index(*m);
move_idxs.push(move_idx);
acc.push((
POLICY_NET.add_bias.vals[move_idx],
Expand Down
12 changes: 6 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ extern crate rand;
extern crate shakmaty;

mod arena;
mod args;
mod chess;
mod evaluation;
mod math;
mod options;
mod search;
mod search_tree;
mod state;
mod tablebase;
mod training;
mod transposition_table;
mod tree_policy;

mod args;
mod evaluation;
mod search;
mod state;
mod training;
mod uci;

fn main() {
Expand Down
21 changes: 11 additions & 10 deletions src/search.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use shakmaty::{CastlingMode, Color, Move};
use shakmaty::{CastlingMode, Color};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};

use crate::chess;
use crate::evaluation;
use crate::options::{get_num_threads, get_policy_temperature, is_chess960};
use crate::search_tree::{MoveEdge, SearchTree};
Expand Down Expand Up @@ -122,7 +123,7 @@ impl Search {
let mvs = state.available_moves();

if mvs.len() == 1 {
let uci_mv = to_uci(&mvs[0]);
let uci_mv = mvs[0].to_uci();
println!("info depth 1 seldepth 1 nodes 1 nps 1 tbhits 0 time 1 pv {uci_mv}");
println!("bestmove {uci_mv}");
return;
Expand Down Expand Up @@ -235,7 +236,7 @@ impl Search {
run_search_thread(&time_management);
self.search_tree.print_info(&time_management);
stop_signal.store(true, Ordering::Relaxed);
println!("bestmove {}", to_uci(&self.best_move().unwrap()));
println!("bestmove {}", self.best_move().to_uci());
});

for _ in 0..(num_threads - 1) {
Expand Down Expand Up @@ -280,8 +281,8 @@ impl Search {
moves.sort_by_key(|(h, e)| (h.average_reward().unwrap_or(*e) * SCALE) as i64);
for (mov, e) in moves {
println!(
"info string {:7} M: {:>6} P: {:>6} V: {:7} E: {:>6} ({:>8})",
format!("{}", mov.get_move()),
"info string {:7} M: {:5} P: {:>6} V: {:7} E: {:>6} ({:>8})",
format!("{}", mov.get_move().to_uci()),
format!("{:3.2}", e * 100.),
format!("{:3.2}", f32::from(mov.policy()) / SCALE * 100.),
mov.visits(),
Expand All @@ -293,21 +294,21 @@ impl Search {
}
}

pub fn principal_variation(&self, num_moves: usize) -> Vec<Move> {
pub fn principal_variation(&self, num_moves: usize) -> Vec<chess::Move> {
self.search_tree
.principal_variation(num_moves)
.into_iter()
.map(MoveEdge::get_move)
.cloned()
.copied()
.collect()
}

pub fn best_move(&self) -> Option<shakmaty::Move> {
self.principal_variation(1).get(0).cloned()
pub fn best_move(&self) -> chess::Move {
*self.principal_variation(1).get(0).unwrap()
}
}

pub fn to_uci(mov: &Move) -> String {
pub fn to_uci(mov: &shakmaty::Move) -> String {
mov.to_uci(CastlingMode::from_chess960(is_chess960()))
.to_string()
}
Expand Down
15 changes: 8 additions & 7 deletions src/search_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ use std::sync::atomic::{
};

use crate::arena::Error as ArenaError;
use crate::chess;
use crate::evaluation::{self, Flag};
use crate::options::{
get_cpuct, get_cpuct_root, get_cvisits_selection, get_policy_temperature,
get_policy_temperature_root,
};
use crate::search::{eval_in_cp, ThreadData};
use crate::search::{to_uci, TimeManagement, SCALE};
use crate::search::{TimeManagement, SCALE};
use crate::state::State;
use crate::transposition_table::{LRAllocator, LRTable, TranspositionTable};
use crate::tree_policy;
Expand Down Expand Up @@ -46,7 +47,7 @@ pub struct MoveEdge {
sum_evaluations: AtomicI64,
visits: AtomicU32,
policy: u16,
mov: shakmaty::Move,
mov: chess::Move,
child: AtomicPtr<PositionNode>,
}

Expand Down Expand Up @@ -95,7 +96,7 @@ impl PositionNode {
}

impl MoveEdge {
fn new(policy: u16, mov: shakmaty::Move) -> Self {
fn new(policy: u16, mov: chess::Move) -> Self {
Self {
policy,
sum_evaluations: AtomicI64::default(),
Expand All @@ -105,7 +106,7 @@ impl MoveEdge {
}
}

pub fn get_move(&self) -> &shakmaty::Move {
pub fn get_move(&self) -> &chess::Move {
&self.mov
}

Expand Down Expand Up @@ -193,7 +194,7 @@ where

#[allow(clippy::cast_sign_loss)]
for (i, x) in hots.iter_mut().enumerate() {
*x = MoveEdge::new((move_eval[i] * SCALE) as u16, moves[i].clone());
*x = MoveEdge::new((move_eval[i] * SCALE) as u16, moves[i]);
}

Ok(PositionNode::new(hots, state_flag))
Expand Down Expand Up @@ -303,7 +304,7 @@ impl SearchTree {
let choice = tree_policy::choose_child(node.hots(), cpuct, fpu);
choice.down();
path.push(choice);
state.make_move(&choice.mov);
state.make_move(choice.mov);

if choice.visits() == 1 {
evaln = evaluation::evaluate_state(&state);
Expand Down Expand Up @@ -453,7 +454,7 @@ impl SearchTree {
let sel_depth = self.max_depth();
let pv = self.principal_variation(depth.max(1));
let pv_string: String = pv.into_iter().fold(String::new(), |mut out, x| {
write!(out, " {}", to_uci(x.get_move())).unwrap();
write!(out, " {}", x.get_move().to_uci()).unwrap();
out
});

Expand Down
Loading

0 comments on commit 018545a

Please sign in to comment.