Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
tigerros committed Sep 22, 2024
1 parent c60e7bb commit 3773554
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 82 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ on: [push, pull_request]

jobs:
test:
runs-on: ubuntu-latest
runs-on: windows-latest

steps:
- name: Checkout code
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ shakmaty = "0.27.1"
paste = "1.0.15"
dry-mods = "0.1.5"
parking_lot = { version = "0.12.3", features = ["arc_lock", "send_guard"], optional = true }
tokio = { version = "1.39.2", features = ["rt", "sync", "io-std", "io-util", "process"], optional = true }
tokio = { version = "1.40.0", features = ["rt", "sync", "io-std", "io-util", "process"], optional = true }

[dev-dependencies]
pretty_assertions = "1.4.0"
Expand Down
13 changes: 7 additions & 6 deletions examples/go.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@
use ruci::messages::{GoMessage, GuiMessage};
use ruci::EngineConnection;
use std::io;
use tokio::io::AsyncBufReadExt;

#[tokio::main]
async fn main() -> io::Result<()> {
let mut engine_conn = EngineConnection::from_path("stockfish").unwrap();

println!("== Sending use UCI message, waiting for uciok");

let (id, options) = engine_conn.use_uci().await?;

println!("== Received uciok");
println!("== ID: {id:?}");
println!("== Options: {options:?}");
println!("== Sending isready message, waiting for readyok");

let (infos, best_move) = engine_conn
.go(GoMessage {
search_moves: None,
Expand All @@ -38,13 +39,13 @@ async fn main() -> io::Result<()> {
infinite: false,
})
.await?;

for info in infos {
println!("Info: {info:?}");
}

println!("Best move: {best_move:?}");

println!("== Sending quit message");
engine_conn.send_message(&GuiMessage::Quit).await?;
println!("== Sent. Program terminated");
Expand Down
Binary file added resources/stockfish.exe
Binary file not shown.
157 changes: 83 additions & 74 deletions src/uci_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::marker::PhantomData;
use std::process::Stdio;
use std::sync::Arc;
use tokio::io;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
Expand Down Expand Up @@ -41,7 +41,7 @@ where
MReceive: Message,
{
pub process: Child,
pub stdout: ChildStdout,
pub stdout: BufReader<ChildStdout>,
pub stdin: ChildStdin,
_phantom: PhantomData<(MSend, MReceive)>,
}
Expand All @@ -51,34 +51,37 @@ where
MSend: Message,
MReceive: Message,
{
pub const fn new(process: Child, stdout: ChildStdout, stdin: ChildStdin) -> Self {
Self {
process,
stdout,
stdin,
_phantom: PhantomData,
}
}

/// # Errors
///
/// [`UciCreationError::Spawn`] is guaranteed not to occur here.
pub fn from_process(mut process: Child) -> Result<Self, UciCreationError> {
let Some(stdout) = process.stdout.take() else {
return Err(UciCreationError::StdoutIsNone);
};

let Some(stdin) = process.stdin.take() else {
return Err(UciCreationError::StdinIsNone);
};

Ok(Self {
process,
stdout,
stdin,
_phantom: PhantomData,
})
}
// pub const fn new(process: Child, stdout: BufReader<ChildStdout>, stdin: BufWriter<ChildStdin>) -> Self {
// Self {
// process,
// stdout,
// stdin,
// _phantom: PhantomData,
// }
// }
//
// /// # Errors
// ///
// /// [`UciCreationError::Spawn`] is guaranteed not to occur here.
// pub fn from_process(mut process: Child) -> Result<Self, UciCreationError> {
// let Some(stdout) = process.stdout.take() else {
// return Err(UciCreationError::StdoutIsNone);
// };
//
// let Some(stdin) = process.stdin.take() else {
// return Err(UciCreationError::StdinIsNone);
// };
//
// let stdout = BufReader::new(stdout);
// let stdin = BufWriter::new(stdin);
//
// Ok(Self {
// process,
// stdout,
// stdin,
// _phantom: PhantomData,
// })
// }

/// Creates a new UCI connection from the given executable path.
///
Expand All @@ -101,11 +104,13 @@ where
let Some(stdout) = process.stdout.take() else {
return Err(UciCreationError::StdoutIsNone);
};

let Some(stdin) = process.stdin.take() else {
return Err(UciCreationError::StdinIsNone);
};

let stdout = BufReader::new(stdout);

Ok(Self {
process,
stdout,
Expand All @@ -129,27 +134,30 @@ where
///
/// See [`Read::read_exact`].
pub async fn skip_lines(&mut self, count: usize) -> io::Result<()> {
let mut buf = [0; 1];
let mut skipped_count = 0;

loop {
self.stdout.read_exact(&mut buf).await?;

if buf[0] == b'\n' {
// CLIPPY: `skipped_count` never overflows because it starts at 0, increments by 1, and stops once `count` is reached.
#[allow(clippy::arithmetic_side_effects)]
{
skipped_count += 1;
}

if skipped_count == count {
break;
}

continue;
}
let mut buf = String::new();

for _ in 0..count {
self.stdout.read_line(&mut buf).await?;
}

// loop {
// self.stdout.read_exact(&mut buf).await?;
//
// if buf[0] == b'\n' {
// // CLIPPY: `skipped_count` never overflows because it starts at 0, increments by 1, and stops once `count` is reached.
// #[allow(clippy::arithmetic_side_effects)]
// {
// skipped_count += 1;
// }
//
// if skipped_count == count {
// break;
// }
//
// continue;
// }
// }

Ok(())
}

Expand All @@ -162,32 +170,12 @@ where
pub async fn read_message(
&mut self,
) -> Result<MReceive, UciReadMessageError<MReceive::ParameterPointer>> {
MReceive::from_str(&self.read_line().await.map_err(UciReadMessageError::Io)?)
let mut line = String::new();
self.stdout.read_line(&mut line).await.map_err(UciReadMessageError::Io)?;

MReceive::from_str(&line)
.map_err(UciReadMessageError::MessageParse)
}

/// Reads one line without the trailing `'\n'` character.
///
/// # Errors
///
/// - Reading resulted in an IO error.
/// - Parsing the message errors.
pub async fn read_line(&mut self) -> io::Result<String> {
let mut s = String::with_capacity(100);
let mut buf = [0; 1];

loop {
self.stdout.read_exact(&mut buf).await?;

if buf[0] == b'\n' {
break;
}

s.push(char::from(buf[0]));
}

Ok(s)
}
}

impl EngineConnection {
Expand Down Expand Up @@ -348,3 +336,24 @@ fn update_id(old_id: &mut Option<IdMessageKind>, new_id: IdMessageKind) {
}
});
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;

#[tokio::test]
async fn skip_lines() {
let mut engine_conn = EngineConnection::from_path("/resources/stockfish.exe").unwrap();

engine_conn.send_message(&GuiMessage::UseUci).await.unwrap();

engine_conn.skip_lines(4).await.unwrap();

let mut line = String::new();
engine_conn.stdout.read_line(&mut line).await.unwrap();

assert_eq!(line, "option name Debug Log File type string default\n");
}
}

0 comments on commit 3773554

Please sign in to comment.