use crate::game::{Base, Game};
use crate::policies::MultiplayerPolicy;
use async_trait::async_trait;
use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::FromIterator;
use std::marker::PhantomData;
pub mod muz;
pub mod puct;
pub mod rave;
pub mod uct;
/// Bound alias for games usable by the tree search in this module: any
/// [`Game`] that is also [`Clone`] (states are cloned when nodes are created).
pub trait MCTSGame = Game + Clone;
use std::sync::RwLock;
use std::sync::{Arc, Weak};
/// Optional back-link from a tree node to its parent: a weak handle to the
/// parent node paired with a move. `WithMCTSPolicy::expand` stores the move
/// that was played from the parent to create the node; the root built in
/// `play` uses `None`.
pub type MCTSNodeParent<G, MCTS> = Option<(Weak<RwLock<MCTSTreeNode<G, MCTS>>>, <G as Base>::Move)>;
/// Shared, lock-protected handle to a tree node, as stored in a parent's
/// `moves` map.
pub type MCTSNodeChild<G, MCTS> = Arc<RwLock<MCTSTreeNode<G, MCTS>>>;
/// A node of the search tree: the structural links (parent and expanded
/// children) wrapped around the per-node payload in [`MCTSNode`].
#[derive(Clone)]
pub struct MCTSTreeNode<G: MCTSGame, MCTS: BaseMCTSPolicy<G>> {
    /// Weak back-link to the parent and the associated move; `None` for a root.
    pub parent: MCTSNodeParent<G, MCTS>,
    /// Children that have already been expanded, keyed by the move leading to them.
    pub moves: HashMap<G::Move, MCTSNodeChild<G, MCTS>>,
    /// Game state and policy statistics attached to this node.
    pub info: MCTSNode<G, MCTS>,
}
impl<G: MCTSGame, MCTS: BaseMCTSPolicy<G>> Debug for MCTSTreeNode<G, MCTS> {
    /// Writes the expanded children followed by this node's payload, ending
    /// with a newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { moves, info, .. } = self;
        writeln!(f, "{:?} ===> {:?}", moves, info)
    }
}
/// Payload of a tree node: the game state it represents plus the statistics
/// the policy keeps for the node and for each move available from it.
#[derive(Clone)]
pub struct MCTSNode<G, MCTS>
where
    G: MCTSGame,
    MCTS: BaseMCTSPolicy<G>,
{
    /// Game state held by this node.
    pub state: G,
    /// Value returned by `play` when the move leading here was applied
    /// (`0.` for the root created in `WithMCTSPolicy::play`).
    pub reward: f32,
    /// Policy-defined statistics for the node itself.
    pub node: MCTS::NodeInfo,
    /// Policy-defined statistics for every move in `state.possible_moves()`.
    pub moves: HashMap<G::Move, MCTS::MoveInfo>,
}
impl<G: MCTSGame, MCTS: BaseMCTSPolicy<G>> Debug for MCTSNode<G, MCTS> {
    /// Writes the node statistics and the per-move statistics; the game state
    /// and reward are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { node, moves, .. } = self;
        write!(f, "NODE: {:?}|| MOVES:{:?}", node, moves)
    }
}
/// Customisation points of the tree search driven by `WithMCTSPolicy`.
///
/// An implementor defines which statistics are stored per node and per move,
/// how a move is scored during selection, how a playout is produced, and how
/// its result is propagated through the tree.
#[async_trait]
pub trait BaseMCTSPolicy<G: MCTSGame>: Sized {
/// Statistics stored once per tree node (`MCTSNode::node`).
type NodeInfo: Debug + Clone + Copy + Send + Sync;
/// Statistics stored once per available move of a node (`MCTSNode::moves`).
type MoveInfo: Debug + Clone + Copy + Send + Sync;
/// Result of `simulate`, handed to `backpropagate`.
type PlayoutInfo: Send + Sync;
/// Scores `action` played from `board`, given the node and move statistics.
///
/// `WithMCTSPolicy::select_move` picks the move with the largest returned
/// value. In this file `exploration` is `true` while descending the tree and
/// `false` when choosing the move finally returned by `play`.
fn get_value(
&self,
board: &G,
action: &G::Move,
node_info: &Self::NodeInfo,
move_info: &Self::MoveInfo,
exploration: bool,
) -> f32;
/// Initial node statistics for a newly created node holding `board`.
fn default_node(&self, board: &G) -> Self::NodeInfo;
/// Initial statistics for `action` as a move available from `board`.
fn default_move(&self, board: &G, action: &G::Move) -> Self::MoveInfo;
/// Consumes a playout result for `leaf`.
///
/// `history` is the sequence of moves selected from the root during the
/// search iteration (empty when called on the root itself from `play`).
/// How the tree statistics are updated is left to the implementor.
fn backpropagate(
&mut self,
leaf: MCTSNodeChild<G, Self>,
history: &[G::Move],
playout: Self::PlayoutInfo,
);
/// Produces a playout result starting from `board`.
async fn simulate(&self, board: &G) -> Self::PlayoutInfo;
}
use float_ord::FloatOrd;
pub struct WithMCTSPolicy<G, MCTS>
where
G: MCTSGame,
MCTS: BaseMCTSPolicy<G>,
{
pub base_mcts: MCTS,
N_PLAYOUTS: usize,
pub root: Option<MCTSNodeChild<G, MCTS>>,
_g: std::marker::PhantomData<G>,
}
impl<G, MCTS> WithMCTSPolicy<G, MCTS>
where
    G: MCTSGame,
    MCTS: BaseMCTSPolicy<G>,
{
    /// Returns the move of `tree_node` with the highest value according to
    /// [`BaseMCTSPolicy::get_value`].
    ///
    /// `exploration` is forwarded unchanged to the policy.
    ///
    /// # Panics
    ///
    /// Panics if the node has no available moves.
    fn select_move(&self, tree_node: &MCTSTreeNode<G, MCTS>, exploration: bool) -> G::Move {
        let info = &tree_node.info;
        info.moves
            .iter()
            .map(|(action, move_info)| {
                let value = self.base_mcts.get_value(
                    &info.state,
                    action,
                    &info.node,
                    move_info,
                    exploration,
                );
                (*action, value)
            })
            .max_by_key(|&(_, value)| FloatOrd(value))
            .map(|(action, _)| action)
            .expect("select_move requires a node with at least one available move")
    }

    /// Descends from `root`, following the best exploratory move at each node,
    /// until the chosen move has no expanded child, leads to a finished state,
    /// or the current node itself is finished.
    ///
    /// Returns the moves chosen along the way and the node at which the
    /// descent stopped. The last entry of the history is the move to expand
    /// from that node; the history is empty only when `root` is already
    /// finished.
    fn select(&self, root: MCTSNodeChild<G, MCTS>) -> (Vec<G::Move>, MCTSNodeChild<G, MCTS>) {
        let mut history: Vec<G::Move> = Vec::new();
        let mut current = root;
        loop {
            // Scope the read guard so `current` can be reassigned or returned
            // afterwards without cloning the handle on every iteration.
            let next = {
                let node = current.read().unwrap();
                if node.info.state.is_finished() {
                    None
                } else {
                    let action = self.select_move(&node, true);
                    history.push(action);
                    // Only keep descending into children that exist and are
                    // not terminal; otherwise stop at the current node.
                    node.moves
                        .get(&action)
                        .filter(|child| !child.read().unwrap().info.state.is_finished())
                        .cloned()
                }
            };
            match next {
                Some(child) => current = child,
                None => return (history, current),
            }
        }
    }

    /// Plays `action` on a copy of `tree_node`'s state, stores the resulting
    /// node as the child of `tree_node` for that move (replacing any previous
    /// child), and returns a handle to it.
    async fn expand(
        &mut self,
        tree_node: MCTSNodeChild<G, MCTS>,
        action: &G::Move,
    ) -> MCTSNodeChild<G, MCTS> {
        // The read guard is a temporary of this statement, so no lock is held
        // across the await below.
        let mut new_state = tree_node.read().unwrap().info.state.clone();
        let reward = new_state.play(action).await;
        let node = self.base_mcts.default_node(&new_state);
        let moves: HashMap<_, _> = new_state
            .possible_moves()
            .iter()
            .map(|m| (*m, self.base_mcts.default_move(&new_state, m)))
            .collect();
        let child = Arc::new(RwLock::new(MCTSTreeNode {
            parent: Some((Arc::downgrade(&tree_node), *action)),
            moves: HashMap::new(),
            info: MCTSNode {
                state: new_state,
                reward,
                node,
                moves,
            },
        }));
        // Keep one handle for the caller instead of re-locking the parent and
        // looking the child up again after insertion.
        tree_node
            .write()
            .unwrap()
            .moves
            .insert(*action, Arc::clone(&child));
        child
    }

    /// Runs one search iteration: select a node, expand the chosen move,
    /// simulate from the new state and backpropagate the playout.
    ///
    /// Does nothing when `root` is already finished, since there is then no
    /// move to expand.
    async fn tree_search(&mut self, root: MCTSNodeChild<G, MCTS>) {
        let (history, last_node) = self.select(root);
        let action = match history.last() {
            Some(action) => *action,
            None => return,
        };
        let created_node = self.expand(last_node, &action).await;
        let state = created_node.read().unwrap().info.state.clone();
        let playout = self.base_mcts.simulate(&state).await;
        self.base_mcts
            .backpropagate(created_node, &history, playout);
    }

    /// Wraps policy `p` so that every call to `play` performs `N_PLAYOUTS`
    /// tree-search iterations. No tree exists until `play` is first called.
    pub fn new(p: MCTS, N_PLAYOUTS: usize) -> Self {
        WithMCTSPolicy {
            base_mcts: p,
            N_PLAYOUTS,
            root: None,
            _g: PhantomData,
        }
    }
}
#[async_trait]
impl<G, MCTS> MultiplayerPolicy<G> for WithMCTSPolicy<G, MCTS>
where
    G: MCTSGame,
    MCTS: BaseMCTSPolicy<G> + Sync + Send,
{
    /// Builds a fresh tree rooted at `board`, runs the configured number of
    /// search iterations and returns the best move without exploration.
    ///
    /// The tree is kept in `self.root` afterwards.
    async fn play(&mut self, board: &G) -> G::Move {
        // Assemble the root payload piece by piece, in the same order the
        // policy hooks were originally invoked.
        let state = board.clone();
        let node = self.base_mcts.default_node(board);
        let move_infos: HashMap<_, _> = board
            .possible_moves()
            .iter()
            .map(|m| (*m, self.base_mcts.default_move(board, m)))
            .collect();
        let root = Arc::new(RwLock::new(MCTSTreeNode {
            parent: None,
            info: MCTSNode {
                reward: 0.,
                state,
                node,
                moves: move_infos,
            },
            moves: HashMap::new(),
        }));

        // Seed the root with one playout before the search iterations start.
        let first_playout = self.base_mcts.simulate(board).await;
        self.base_mcts
            .backpropagate(Arc::clone(&root), &[], first_playout);

        for _ in 0..self.N_PLAYOUTS {
            self.tree_search(Arc::clone(&root)).await;
        }

        let best = {
            let root_node = root.read().unwrap();
            self.select_move(&root_node, false)
        };
        self.root = Some(root);
        best
    }
}