1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
use std::path::Path;
use std::sync::{atomic::AtomicBool, Arc, RwLock};
use tensorflow::{Graph, Session, SessionOptions, SessionRunArgs, Tensor};
/// A loaded model that can be shared across threads.
///
/// Holds an `AtomicBool` flag next to the `(Graph, Session)` pair, which sits behind an `RwLock`
/// so many readers can run inference while a writer can swap the model out.
///
/// NOTE(review): the meaning of the `AtomicBool` is not visible in this file — presumably a
/// "model updated"/"ready" flag; confirm against its users.
pub type ThreadSafeModel = Arc<(AtomicBool, RwLock<(Graph, Session)>)>;
/// Three-way sign of `x`: `1.0` if positive, `0.0` if zero (either `+0.0` or `-0.0`),
/// `-1.0` otherwise.
///
/// Unlike [`f32::signum`], zero maps to `0.0`. NaN is not ordered against zero and falls
/// through to `-1.0`.
fn sign(x: f32) -> f32 {
    use std::cmp::Ordering;

    match x.partial_cmp(&0.0) {
        Some(Ordering::Greater) => 1.0,
        Some(Ordering::Equal) => 0.0,
        Some(Ordering::Less) | None => -1.0,
    }
}
/// Converts a batch of categorical value distributions ("supports") back to scalar values.
///
/// `support` is read as `batch_size` consecutive rows of `2 * support_size + 1` weights. The
/// weights in a row belong to the integer atoms `-support_size..=support_size`. Each row is
/// handled in two steps:
///
/// 1. It is collapsed to its weighted sum `Σ weight * atom`.
/// 2. That sum is passed through the closed-form inverse of the scaling transform
///    `h(x) = sign(x) * (sqrt(|x| + 1) - 1) + EPSILON * x`, with `EPSILON = 0.001`.
///
/// Returns a rank-1 tensor of length `batch_size`.
///
/// # Panics
/// Panics if `support` holds fewer than `batch_size * (2 * support_size + 1)` elements.
pub fn support_to_value(
    support: &Tensor<f32>,
    batch_size: usize,
    support_size: usize,
) -> Tensor<f32> {
    // Epsilon of the scaling transform h. Presumably it must match the forward transform
    // used when building training targets — confirm against that code.
    const EPSILON: f32 = 0.001;

    let width = 2 * support_size + 1;
    // Fail with a descriptive message rather than an opaque out-of-bounds index panic.
    assert!(
        support.len() >= batch_size * width,
        "support has {} elements but {} rows of {} atoms were requested",
        support.len(),
        batch_size,
        width
    );

    let mut res = Tensor::<f32>::new(&[batch_size as u64]);
    let lowest_atom = -(support_size as isize);
    // `res` has exactly `batch_size` slots, so the zip stops after `batch_size` rows.
    for (out, row) in res.iter_mut().zip(support.chunks_exact(width)) {
        // Weighted sum of the atoms; summed in ascending atom order.
        let expected: f32 = row
            .iter()
            .zip(lowest_atom..)
            .map(|(&weight, atom)| weight * atom as f32)
            .sum();
        // Invert h: solve EPSILON*t^2 + t - (1 + EPSILON + |y|) = 0 for t = sqrt(|x| + 1),
        // then recover |x| = t^2 - 1 and restore the sign.
        let root = ((1. + 4. * EPSILON * (expected.abs() + 1. + EPSILON)).sqrt() - 1.)
            / (2. * EPSILON);
        *out = sign(expected) * (root.powi(2) - 1.);
    }
    res
}
#[cfg(test)]
mod tests {
    use super::*;

    /// h^-1(1) for EPSILON = 0.001, computed by hand from the closed-form inverse.
    const INV_H_ONE: f32 = 2.98806;

    /// Builds a `[1, 3]` support tensor (one row, `support_size = 1`) from `weights`.
    fn support_of(weights: [f32; 3]) -> Tensor<f32> {
        let mut support = Tensor::<f32>::new(&[1, 3]);
        support.copy_from_slice(&weights);
        support
    }

    /// Asserts `actual` is within 1e-3 of `expected`; the f32 inverse transform is not exact.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-3,
            "got {}, expected approximately {}",
            actual,
            expected
        );
    }

    #[test]
    fn test_support_to_value() {
        // All mass on atom -1: weighted sum -1, mapped through h^-1.
        assert_close(support_to_value(&support_of([1., 0., 0.]), 1, 1)[0], -INV_H_ONE);
        // All mass on atom 0: weighted sum 0, which h^-1 maps to exactly 0.
        assert_eq!(support_to_value(&support_of([0., 1., 0.]), 1, 1)[0], 0.0);
        // All mass on atom +1: symmetric to the first case.
        assert_close(support_to_value(&support_of([0., 0., 1.]), 1, 1)[0], INV_H_ONE);
    }

    #[test]
    fn test_support_to_value_batch() {
        let mut support = Tensor::<f32>::new(&[2, 3]);
        support.copy_from_slice(&[1., 0., 0., 0., 0., 1.]);
        let values = support_to_value(&support, 2, 1);
        assert_eq!(values.len(), 2);
        assert_close(values[0], -INV_H_ONE);
        assert_close(values[1], INV_H_ONE);
    }
}
/// Runs the prediction network on an encoded board.
///
/// Feeds `board` into the `serving_default_board` op and fetches outputs 0 and 1 of the
/// `StatefulPartitionedCall` op.
///
/// Returns `(policy, value)`: output 0 is returned as the policy and output 1 as the value.
///
/// # Panics
/// Panics with a descriptive message if either op is missing from `graph`, or if the session
/// run or a fetch fails. This typically means the loaded model does not match this signature.
pub fn call_prediction(
    session: &Session,
    graph: &Graph,
    board: &Tensor<f32>,
) -> (Tensor<f32>, Tensor<f32>) {
    let board_op = graph
        .operation_by_name_required("serving_default_board")
        .expect("prediction graph has no `serving_default_board` op");
    let output_op = graph
        .operation_by_name_required("StatefulPartitionedCall")
        .expect("prediction graph has no `StatefulPartitionedCall` op");

    let mut args = SessionRunArgs::new();
    args.add_feed(&board_op, 0, board);
    let policy_req = args.request_fetch(&output_op, 0);
    let value_req = args.request_fetch(&output_op, 1);
    session
        .run(&mut args)
        .expect("prediction session run failed");

    let policy_tensor: Tensor<f32> = args
        .fetch(policy_req)
        .expect("failed to fetch prediction policy (output 0)");
    let value_tensor: Tensor<f32> = args
        .fetch(value_req)
        .expect("failed to fetch prediction value (output 1)");
    (policy_tensor, value_tensor)
}
/// Runs the dynamics network on an encoded board and an encoded action.
///
/// Feeds `board` into `serving_default_board` and `action` into `serving_default_action`,
/// then fetches outputs 0 and 1 of the `StatefulPartitionedCall` op.
///
/// Returns `(reward, next_board)`. The reward comes from output 1 and the next board from
/// output 0, so the tuple order is the reverse of the output index order.
///
/// # Panics
/// Panics with a descriptive message if any op is missing from `graph`, or if the session
/// run or a fetch fails. This typically means the loaded model does not match this signature.
pub fn call_dynamics(
    session: &Session,
    graph: &Graph,
    board: &Tensor<f32>,
    action: &Tensor<f32>,
) -> (Tensor<f32>, Tensor<f32>) {
    let board_op = graph
        .operation_by_name_required("serving_default_board")
        .expect("dynamics graph has no `serving_default_board` op");
    let action_op = graph
        .operation_by_name_required("serving_default_action")
        .expect("dynamics graph has no `serving_default_action` op");
    let output_op = graph
        .operation_by_name_required("StatefulPartitionedCall")
        .expect("dynamics graph has no `StatefulPartitionedCall` op");

    let mut args = SessionRunArgs::new();
    args.add_feed(&board_op, 0, board);
    args.add_feed(&action_op, 0, action);
    let reward_req = args.request_fetch(&output_op, 1);
    let next_board_req = args.request_fetch(&output_op, 0);
    session
        .run(&mut args)
        .expect("dynamics session run failed");

    let reward_tensor: Tensor<f32> = args
        .fetch(reward_req)
        .expect("failed to fetch dynamics reward (output 1)");
    let next_board_tensor: Tensor<f32> = args
        .fetch(next_board_req)
        .expect("failed to fetch dynamics next board (output 0)");
    (reward_tensor, next_board_tensor)
}
/// Runs the representation network on an encoded board.
///
/// Feeds `board` into the `serving_default_board` op and returns output 0 of the
/// `StatefulPartitionedCall` op.
///
/// # Panics
/// Panics with a descriptive message if either op is missing from `graph`, or if the session
/// run or the fetch fails. This typically means the loaded model does not match this signature.
pub fn call_representation(session: &Session, graph: &Graph, board: &Tensor<f32>) -> Tensor<f32> {
    let board_op = graph
        .operation_by_name_required("serving_default_board")
        .expect("representation graph has no `serving_default_board` op");
    let output_op = graph
        .operation_by_name_required("StatefulPartitionedCall")
        .expect("representation graph has no `StatefulPartitionedCall` op");

    let mut args = SessionRunArgs::new();
    args.add_feed(&board_op, 0, board);
    let repr_board_req = args.request_fetch(&output_op, 0);
    session
        .run(&mut args)
        .expect("representation session run failed");

    args.fetch(repr_board_req)
        .expect("failed to fetch representation output (output 0)")
}
/// Loads the TensorFlow SavedModel tagged `serve` from the directory at `path`.
///
/// The session is created with a hand-serialized `ConfigProto` that enables GPU memory
/// growth (see `SESSION_CONFIG` below).
///
/// Returns the populated `Graph` together with the `Session` bound to it.
///
/// # Panics
/// - Logs an error and panics if `path` does not exist.
/// - Panics if TensorFlow rejects the session config.
/// - Panics if the SavedModel cannot be loaded.
pub fn load_model(path: &str) -> (Graph, Session) {
    if !Path::new(path).exists() {
        log::error!("Couldn't find model at {}", path);
        panic!("Couldn't find model at {}", path);
    }

    // Serialized `ConfigProto { gpu_options { allow_growth: true } }`:
    //   50 = ConfigProto field 6 (`gpu_options`), wire type 2 (length-delimited)
    //    2 = length of the embedded GPUOptions message
    //   32 = GPUOptions field 4 (`allow_growth`), wire type 0 (varint)
    //    1 = true
    // This makes TensorFlow allocate GPU memory on demand instead of reserving it all at start.
    const SESSION_CONFIG: [u8; 4] = [50, 2, 32, 1];

    let mut graph = Graph::new();
    let mut options = SessionOptions::new();
    options
        .set_config(&SESSION_CONFIG)
        .expect("TensorFlow rejected the session ConfigProto");
    let session = Session::from_saved_model(&options, &["serve"], &mut graph, path)
        .unwrap_or_else(|e| panic!("Failed to load SavedModel from {}: {}", path, e));
    (graph, session)
}