use std::path::Path;
use std::sync::{atomic::AtomicBool, Arc, RwLock};
use tensorflow::{Graph, Session, SessionOptions, SessionRunArgs, Tensor};

/// Access to a TF model behind an `Arc` and an `RwLock`.
/// The `AtomicBool` signals the file loader's intention to acquire the lock.
pub type ThreadSafeModel = Arc<(AtomicBool, RwLock<(Graph, Session)>)>;
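
// Minimal usage sketch (the helper name and the exact back-off policy are
// assumptions, not part of the original design): the file loader raises the
// flag before asking for the write lock, and inference threads check it and
// yield instead of piling up behind the reload.
pub fn with_model<R>(model: &ThreadSafeModel, f: impl FnOnce(&Graph, &Session) -> R) -> R {
    // Back off while the loader has signalled that it wants the lock.
    while model.0.load(std::sync::atomic::Ordering::Acquire) {
        std::thread::yield_now();
    }
    let guard = model.1.read().unwrap();
    f(&guard.0, &guard.1)
}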

/// Sign of `x`, returning 0.0 for 0.0 (unlike `f32::signum`, which returns 1.0 there).
fn sign(x: f32) -> f32 {
    if x > 0. {
        1.
    } else if x == 0. {
        0.
    } else {
        -1.
    }
}

/// Converts a support encoding of a scalar to the corresponding value:
/// takes the expectation over the support bins, then applies the inverse
/// of the scaling transform (with epsilon = 0.001).
pub fn support_to_value(
    support: &Tensor<f32>,
    batch_size: usize,
    support_size: usize,
) -> Tensor<f32> {
    let mut res = Tensor::new(&[batch_size as u64]);

    for i in 0..batch_size {
        // Expected value over the support bins [-support_size, support_size].
        let value: f32 = (-(support_size as isize)..(support_size as isize + 1))
            .enumerate()
            .map(|(j, v)| support[(2 * support_size + 1) * i + j] * (v as f32))
            .sum();
        // Invert the scaling transform h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x,
        // with eps = 0.001.
        let value: f32 = sign(value)
            * ((((1. + 4. * 0.001 * (value.abs() + 1. + 0.001)).sqrt() - 1.) / (2. * 0.001))
                .powi(2)
                - 1.);

        res[i] = value;
    }
    res
}
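
// Not in the original file: a minimal sketch of the corresponding forward
// encoding (scalar -> support), assuming the MuZero-style transform
// h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x with eps = 0.001 and a
// two-hot projection onto the integer bins [-support_size, support_size].
pub fn value_to_support(values: &[f32], support_size: usize) -> Tensor<f32> {
    let width = 2 * support_size + 1;
    let mut res = Tensor::new(&[values.len() as u64, width as u64]);
    for (i, &v) in values.iter().enumerate() {
        // Scaling transform, then clamp into the representable range.
        let x = sign(v) * ((v.abs() + 1.).sqrt() - 1.) + 0.001 * v;
        let x = x.clamp(-(support_size as f32), support_size as f32);
        // Split the probability mass between the two nearest bins.
        let low = x.floor();
        let frac = x - low;
        let low_idx = (low + support_size as f32) as usize;
        res[i * width + low_idx] = 1. - frac;
        if frac > 0. {
            res[i * width + low_idx + 1] = frac;
        }
    }
    res
}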

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_support_to_value() {
        let mut support = Tensor::new(&[1, 3]);

        // All mass on the -1 bin: the decoded value must be negative.
        support[0] = 1.0;
        support[1] = 0.;
        support[2] = 0.;
        let low = support_to_value(&support, 1, 1)[0];
        println!("=> {:?}", low);
        assert!(low < 0.);

        // All mass on the 0 bin: the decoded value must be exactly zero.
        support[0] = 0.;
        support[1] = 1.;
        support[2] = 0.;
        let mid = support_to_value(&support, 1, 1)[0];
        println!("=> {:?}", mid);
        assert_eq!(mid, 0.);

        // All mass on the +1 bin: positive, and symmetric with the -1 case.
        support[0] = 0.;
        support[1] = 0.;
        support[2] = 1.;
        let high = support_to_value(&support, 1, 1)[0];
        println!("=> {:?}", high);
        assert!(high > 0.);
        assert!((high + low).abs() < 1e-5);
    }
}

/// Run inference on the prediction network, returning the policy and value tensors.
pub fn call_prediction(
    session: &Session,
    graph: &Graph,
    board: &Tensor<f32>,
) -> (Tensor<f32>, Tensor<f32>) {
    let board_op = graph
        .operation_by_name_required("serving_default_board")
        .unwrap();
    let output_op = graph
        .operation_by_name_required("StatefulPartitionedCall")
        .unwrap();
    let mut args = SessionRunArgs::new();
    args.add_feed(&board_op, 0, board);

    let policy_req = args.request_fetch(&output_op, 0);
    let value_req = args.request_fetch(&output_op, 1);
    session.run(&mut args).unwrap();

    let policy_tensor: Tensor<f32> = args.fetch(policy_req).unwrap();
    let value_tensor: Tensor<f32> = args.fetch(value_req).unwrap();
    (policy_tensor, value_tensor)
}

/// Run inference on the dynamics network, returning the reward and next hidden-state tensors.
pub fn call_dynamics(
    session: &Session,
    graph: &Graph,
    board: &Tensor<f32>,
    action: &Tensor<f32>,
) -> (Tensor<f32>, Tensor<f32>) {
    let board_op = graph
        .operation_by_name_required("serving_default_board")
        .unwrap();
    let action_op = graph
        .operation_by_name_required("serving_default_action")
        .unwrap();
    let output_op = graph
        .operation_by_name_required("StatefulPartitionedCall")
        .unwrap();
    let mut args = SessionRunArgs::new();
    args.add_feed(&board_op, 0, board);
    args.add_feed(&action_op, 0, action);

    let reward_req = args.request_fetch(&output_op, 1);
    let next_board_req = args.request_fetch(&output_op, 0);
    session.run(&mut args).unwrap();

    let reward_tensor: Tensor<f32> = args.fetch(reward_req).unwrap();
    let next_board_tensor: Tensor<f32> = args.fetch(next_board_req).unwrap();
    (reward_tensor, next_board_tensor)
}

/// Run inference on the representation network, returning the encoded (hidden-state) board tensor.
pub fn call_representation(session: &Session, graph: &Graph, board: &Tensor<f32>) -> Tensor<f32> {
    let board_op = graph
        .operation_by_name_required("serving_default_board")
        .unwrap();
    let output_op = graph
        .operation_by_name_required("StatefulPartitionedCall")
        .unwrap();
    let mut args = SessionRunArgs::new();
    args.add_feed(&board_op, 0, board);

    let repr_board_req = args.request_fetch(&output_op, 0);
    session.run(&mut args).unwrap();

    args.fetch(repr_board_req).unwrap()
}
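
// Illustrative sketch only: how the three calls above compose into a
// MuZero-style rollout. It assumes each network is a separately loaded saved
// model and that the action tensors are already encoded the way the dynamics
// network expects.
pub fn rollout(
    representation: &(Graph, Session),
    dynamics: &(Graph, Session),
    prediction: &(Graph, Session),
    board: &Tensor<f32>,
    actions: &[Tensor<f32>],
) -> Vec<(Tensor<f32>, Tensor<f32>)> {
    // Encode the observation once, then unroll purely in latent space.
    let mut latent = call_representation(&representation.1, &representation.0, board);
    let mut trajectory = Vec::with_capacity(actions.len());
    for action in actions {
        let (policy, value) = call_prediction(&prediction.1, &prediction.0, &latent);
        trajectory.push((policy, value));
        let (_reward, next_latent) = call_dynamics(&dynamics.1, &dynamics.0, &latent, action);
        latent = next_latent;
    }
    trajectory
}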

/// Load a TensorFlow SavedModel into a graph and session.
pub fn load_model(path: &str) -> (Graph, Session) {
    /* Check that the model exists. */
    if !Path::new(path).exists() {
        log::error!("Couldn't find model at {}", path);
        panic!("Couldn't find model at {}", path);
    };

    let mut graph = Graph::new();
    let mut options = SessionOptions::new();
    /* The bytes below are a serialized ConfigProto that enables
     * gpu_options.allow_growth. To regenerate it from Python:
     *      config = tf.compat.v1.ConfigProto()
     *      config.gpu_options.allow_growth = True
     *      config.SerializeToString()
     */
    let configuration_buf = [50, 2, 32, 1];
    options.set_config(&configuration_buf).unwrap();
    let session = Session::from_saved_model(&options, &["serve"], &mut graph, path).unwrap();
    (graph, session)
}
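
// End-to-end usage sketch (this helper, the board shape, and the wiring are
// illustrative assumptions): load a saved model, wrap it in the shared handle,
// and run one prediction under the read lock.
pub fn example_predict(path: &str) {
    let model: ThreadSafeModel = Arc::new((
        AtomicBool::new(false),
        RwLock::new(load_model(path)),
    ));
    let board: Tensor<f32> = Tensor::new(&[1, 42]); // placeholder shape
    let guard = model.1.read().unwrap();
    // The stored tuple is (Graph, Session); call_prediction takes (Session, Graph).
    let (policy, value) = call_prediction(&guard.1, &guard.0, &board);
    println!("policy = {:?}, value = {:?}", policy.to_vec(), value.to_vec());
}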