
Commit

remove divergence computation, fix TC evaluation (#171)
Browse files Browse the repository at this point in the history
the divergence computation significantly reduces simulation speed and
there's little evidence to suggest it's actually valuable; the TC (Ten of
Clubs) evaluation was simply incorrect.

this fixes the bad play observed in game
4eec9c45-cacb-4d34-b5c8-62ae4011afd0, where the bot held the AS, pretty
much guaranteeing it would take the charged QS.
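
To see why the TC evaluation was wrong: when the Ten of Clubs winner is still unknown, ApproximateScores multiplies each player's score by 1.0 + tf * p, where p is the network's predicted probability that the player takes the ten (see the neural_network.rs hunk below). Setting p = 1 says a known winner's multiplier should be 1.0 + tf, but the old code used tf * 2.0, which only agrees when tf happens to be 1.0. A minimal sketch of the corrected payoff; the probabilistic form comes straight from the diff, while treating tf as a charge-dependent factor whose uncharged value is the visible 1.0 else-branch is an assumption:

    // Hypothetical stand-in for the fixed branch in bot/src/neural_network.rs.
    // With win_probability = 1.0 this reduces to 1.0 + tf, matching the
    // known-winner case; the old `tf * 2.0` diverges whenever tf != 1.0,
    // i.e. exactly when the ten is charged.
    fn ten_of_clubs_multiplier(tf: f32, win_probability: f32) -> f32 {
        1.0 + tf * win_probability
    }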

Co-authored-by: Tim Wilson <[email protected]>
tjwilson90 and twilson-palantir authored Apr 16, 2021
1 parent c2fcc91 commit e6df5af
Showing 5 changed files with 39 additions and 94 deletions.
api/src/game.rs: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ use crate::{
     Seat, UserId,
 };
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Game<S> {
     pub events: Vec<GameEvent>,
     pub subscribers: Vec<(UserId, S)>,
api/src/seed.rs: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ impl Seed {
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct HashedSeed {
     seed: [u8; 32],
 }
bot/examples/play_one.rs: 19 additions & 35 deletions
@@ -4,6 +4,7 @@ use turbo_hearts_api::{
 };
 use turbo_hearts_bot::{Algorithm, NeuralNetworkBot};
 
+#[derive(Clone)]
 struct State {
     game: Game<()>,
     bot_state: BotState,
@@ -85,53 +86,36 @@ fn main() {
         .filter_module("turbo_hearts_bot", LevelFilter::Debug)
         .is_test(true)
         .try_init();
-    let mut state = State::new(Seat::South, PassDirection::Across);
+    let mut state = State::new(Seat::North, PassDirection::Right);
     state.deal([
-        "J43S 8765H 83D AK74C".parse().unwrap(),
-        "KT82S KQ942H 42D 63C".parse().unwrap(),
-        "Q65S T3H JT65D Q852C".parse().unwrap(),
-        "A97S AJH AKQ97D JT9C".parse().unwrap(),
+        "QT64S Q982H K74D 86C".parse().unwrap(),
+        "A975S J653H AJ3D AJC".parse().unwrap(),
+        "KH 9852D QT975432C".parse().unwrap(),
+        "KJ832S AT74H QT6D KC".parse().unwrap(),
     ]);
-    state.send_pass(Seat::North, "83D AC".parse().unwrap());
-    state.send_pass(Seat::East, "KS KQH".parse().unwrap());
-    state.send_pass(Seat::South, "QS T3H".parse().unwrap());
-    state.send_pass(Seat::West, "7S JH JC".parse().unwrap());
+    state.send_pass(Seat::North, "QS 9H 8C".parse().unwrap());
+    state.send_pass(Seat::East, "AS AJC".parse().unwrap());
+    state.send_pass(Seat::South, "852D".parse().unwrap());
+    state.send_pass(Seat::West, "TH TD KC".parse().unwrap());
 
-    state.recv_pass(Seat::North, "QS T3H".parse().unwrap());
-    state.recv_pass(Seat::East, "7S JH JC".parse().unwrap());
-    state.recv_pass(Seat::South, "83D AC".parse().unwrap());
-    state.recv_pass(Seat::West, "KS KQH".parse().unwrap());
+    state.recv_pass(Seat::North, "AS AJC".parse().unwrap());
+    state.recv_pass(Seat::East, "852D".parse().unwrap());
+    state.recv_pass(Seat::South, "TH TD KC".parse().unwrap());
+    state.recv_pass(Seat::West, "QS 9H 8C".parse().unwrap());
 
-    state.charge(Seat::West, "TC".parse().unwrap());
+    state.charge(Seat::South, "TC".parse().unwrap());
+    state.charge(Seat::West, "QS".parse().unwrap());
     state.charge(Seat::North, "".parse().unwrap());
     state.charge(Seat::East, "".parse().unwrap());
     state.charge(Seat::South, "".parse().unwrap());
 
-    for c in &["2C", "9C", "7C", "6C", "8C", "TC", "4C", "3C"] {
+    for c in &["2C", "8C", "AC", "7S"] {
         state.play(c.parse().unwrap());
     }
-    for c in &["AD", "QS", "4D", "3D"] {
+    for c in &["4S", "5S", "9D", "JS"] {
         state.play(c.parse().unwrap());
     }
-    for c in &["AS", "JS", "8S", "6S"] {
-        state.play(c.parse().unwrap());
-    }
-    for c in &["KD", "KC", "2D", "TD"] {
-        state.play(c.parse().unwrap());
-    }
-    for c in &["9D", "4S", "JC", "5D", "QD", "3S", "7S", "6D"] {
-        state.play(c.parse().unwrap());
-    }
-    for c in &["KS", "8H", "2S", "5S"] {
-        state.play(c.parse().unwrap());
-    }
-    for c in &["AH", "7H", "4H", "AC"] {
-        state.play(c.parse().unwrap());
-    }
-    for c in &["KH", "6H", "2H", "8D"] {
-        state.play(c.parse().unwrap());
-    }
-    for c in &["QH", "5H", "JH"] {
+    for c in &["8S", "6S", "9S", "KH", "3S"] {
         state.play(c.parse().unwrap());
     }
     println!("{}", state.bot.play(&state.bot_state, &state.game.state));
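
The rewritten example appears to reconstruct the position from game 4eec9c45-cacb-4d34-b5c8-62ae4011afd0 cited above: the bot sits North, receives the AS in the pass, and West charges the QS, so the final println shows the card the bot now chooses at the point where it previously clung to the AS.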
bot/src/neural_network.rs: 16 additions & 55 deletions
@@ -70,6 +70,7 @@ fn load_model(lead: bool, policy: bool) -> Result<TypedRunnableModel<TypedModel>
     Ok(model.into_optimized()?.into_runnable()?)
 }
 
+#[derive(Clone)]
 pub struct NeuralNetworkBot {
     hand_maker: HandMaker,
     initial_state: GameState,
@@ -88,23 +89,6 @@ impl NeuralNetworkBot {
             plays: Vec::with_capacity(52),
         }
     }
-
-    fn total_divergence(&self, hands: [Cards; 4]) -> f32 {
-        let brute_force = ShallowBruteForce::new(hands);
-        let mut state = self.initial_state.clone();
-        let mut divergence = 1.0;
-        for &play in &self.plays {
-            let seat = state.next_actor.unwrap();
-            let plays = state
-                .legal_plays(hands[seat.idx()])
-                .distinct_plays(state.played, state.current_trick);
-            if plays.len() > 1 {
-                divergence += local_divergence(&brute_force, &state, seat, play, plays);
-            }
-            state.apply(&GameEvent::Play { seat, card: play });
-        }
-        divergence
-    }
 }
 
 impl Algorithm for NeuralNetworkBot {
@@ -129,19 +113,27 @@ impl NeuralNetworkBot {
         while now.elapsed().as_millis() < 4500 {
             iters += 1;
             let hands = self.hand_maker.make();
-            let divergence = self.total_divergence(hands);
+            let brute_force = ShallowBruteForce::new(hands);
             for card in distinct_plays {
                 let mut game = game_state.clone();
                 game.apply(&GameEvent::Play {
                     seat: bot_state.seat,
                     card,
                 });
-                let mut brute_force = ShallowBruteForce::new(hands);
                 let scores = brute_force.solve(&mut game);
-                *money_counts.entry(card).or_default() += scores.money(bot_state.seat) / divergence;
+                *money_counts.entry(card).or_default() += scores.money(bot_state.seat);
             }
         }
-        debug!("{} iterations, {:?}", iters, money_counts);
+        if log::log_enabled!(log::Level::Debug) {
+            debug!(
+                "{} iterations, {:?}",
+                iters,
+                money_counts
+                    .iter()
+                    .map(|(k, v)| (k, *v / iters as f32))
+                    .collect::<HashMap<_, _>>()
+            );
+        }
         let mut best_card = Card::TwoClubs;
         let mut best_money = f32::MIN;
         for (card, money) in money_counts.into_iter() {
@@ -176,7 +168,7 @@ impl ShallowBruteForce {
         Self { hands }
     }
 
-    fn solve(&mut self, state: &mut GameState) -> ApproximateScores {
+    fn solve(&self, state: &mut GameState) -> ApproximateScores {
         if state.played.len() >= 48 {
             while state.played != Cards::ALL {
                 let seat = state.next_actor.unwrap();
@@ -393,6 +385,7 @@ impl ShallowBruteForce {
     }
 }
 
+#[derive(Debug)]
 struct ApproximateScores {
     scores: [f32; 4],
 }
@@ -475,7 +468,7 @@ impl ApproximateScores {
             1.0
         };
         if let Some(s) = state.won.ten_winner() {
-            scores[s.idx()] *= tf * 2.0;
+            scores[s.idx()] *= 1.0 + tf;
         } else {
             let ten = output[2].as_slice::<f32>().unwrap();
             scores[0] *= 1.0 + tf * ten[north];
@@ -506,35 +499,3 @@ fn choose(bot_state: &BotState, card: Card, legal_plays: Cards, distinct_plays:
     let index = rand::thread_rng().gen_range(0..cards.len());
     cards.into_iter().nth(index).unwrap()
 }
-
-fn local_divergence(
-    brute_force: &ShallowBruteForce,
-    state: &GameState,
-    seat: Seat,
-    play: Card,
-    plays: Cards,
-) -> f32 {
-    let equivalent = if plays.contains(play) {
-        play
-    } else {
-        plays.above(play).min()
-    };
-    let mut best_money = f32::MIN;
-    for card in plays - equivalent {
-        let scores = brute_force.generate_value(&mut {
-            let mut state = state.clone();
-            state.apply(&GameEvent::Play { seat, card });
-            state
-        });
-        let money = scores.money(seat);
-        if money > best_money {
-            best_money = money;
-        }
-    }
-    let scores = brute_force.generate_value(&mut {
-        let mut state = state.clone();
-        state.apply(&GameEvent::Play { seat, card: play });
-        state
-    });
-    f32::max(best_money - scores.money(seat), 0.0)
-}
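
With total_divergence and local_divergence gone, NeuralNetworkBot::play reduces to plain Monte Carlo scoring: repeatedly sample a deal of the unseen cards, brute-force every distinct candidate play on that deal, and accumulate the raw money per card. A minimal self-contained sketch of that loop, where sample_money is a hypothetical stand-in for the HandMaker::make plus ShallowBruteForce::solve pair:

    use std::collections::HashMap;

    // Hypothetical stand-in: deal the unseen cards once and return the money
    // won when `card` is chosen on that deal.
    fn sample_money(card: &str, sample: u32) -> f32 {
        card.len() as f32 - (sample % 3) as f32
    }

    fn pick_card(candidates: &[&'static str], samples: u32) -> &'static str {
        let mut money_counts: HashMap<&'static str, f32> = HashMap::new();
        for s in 0..samples {
            for &card in candidates {
                // each sample now adds raw money; the removed code divided by
                // a divergence weight computed from the whole play history
                *money_counts.entry(card).or_default() += sample_money(card, s);
            }
        }
        // argmax over accumulated money; dividing by `samples` first (as the
        // new debug log does for display) would not change which card wins
        candidates
            .iter()
            .copied()
            .max_by(|a, b| money_counts[a].partial_cmp(&money_counts[b]).unwrap())
            .unwrap()
    }

Dropping the per-sample weight also removes a ShallowBruteForce evaluation of every historical play on every iteration, which is where the speedup claimed in the commit message comes from.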
neural_net.py: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 # python3 -m venv --system-site-packages ./turbo-hearts-venv
 # source ./turbo-hearts-venv/bin/activate
 # pip install --upgrade pip
-# pip install --upgrade tensorflow keras-tuner onnx onnxmltools
+# pip install --upgrade tensorflow==2.2.0 keras-tuner onnx onnxmltools
 
 import os
 import tensorflow as tf
@@ -90,7 +90,7 @@ def build_model(lead, policy, hp):
     layer = keras.layers.concatenate(inputs)
     layer = keras.layers.Dropout(0.05)(layer)
 
-    for i in range(1 if hp is None else hp.Int('num_layers', 2, 4)):
+    for i in range(2 if hp is None else hp.Int('num_layers', 2, 4)):
         units = [500, 500][i] if hp is None else hp.Int('units' + str(i), min_value=384, max_value=576, step=64)
         layer = keras.layers.Dense(units = units, activation = 'relu')(layer)
         layer = keras.layers.Dropout(0.05)(layer)
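Besides pinning TensorFlow to 2.2.0, the Python change fixes the default model built when hp is None: range(1) constructed only one dense layer even though the units list [500, 500] describes two, so the network now gets the apparently intended two 500-unit hidden layers.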
