Source code for blackbox_mpc.dynamics_functions.deterministic_mlp

import tensorflow as tf


[docs]class DeterministicMLP(tf.Module):
[docs] def __init__(self, layers, activation_functions, loss_fn=tf.keras.losses.MeanSquaredError(), name=None): """ A deterministic fully connected MLP dynamics function class for (s_t, a_t) - > (s_t+1) Parameters --------- name: str Defines the name of the block of the deterministic MLP dynamics function. """ super(DeterministicMLP, self).__init__(name=name) self.layers = [] for i in range(1, len(layers)): self.layers.append( tf.keras.layers.Dense(layers[i], input_shape=(layers[i-1],), activation=activation_functions[i-1], dtype='float32')) self.loss_fn = loss_fn
[docs] @tf.function def __call__(self, x, train): """ This is the call function for the deterministic fully connected MLP dynamics function class. Parameters --------- x: tf.float32 Defines the (s_t, a_t) which is the state and action stacked on top of each other, (dims = Batch X (dim_S + dim_U)) train: tf.bool Defines whether the network should run in train mode or not. Returns ------- output: tf.float32 Defines the next state (s_t+1) with (dims = Batch X dim_S) """ for layer in self.layers.layers: x = layer(x) return x
[docs] @tf.function def get_loss(self, expected_output, predictions): """ This is the training loss function for the deterministic fully connected MLP dynamics function class. Parameters --------- expected_output: tf.float32 Defines the next state (s_t+1) ground truth with (dims = Batch X dim_S) predictions: tf.float32 Defines the next state (s_t+1) predicted values with (dims = Batch X dim_S) Returns ------- train_loss: tf.float32 Defines the training loss as a scalar for the whole batch """ return self.loss_fn(expected_output, predictions)
[docs] @tf.function def get_validation_loss(self, expected_output, predictions): """ This is the validation loss function for the deterministic fully connected MLP dynamics function class. Parameters --------- expected_output: tf.float32 Defines the next state (s_t+1) ground truth with (dims = Batch X dim_S) predictions: tf.float32 Defines the next state (s_t+1) predicted values with (dims = Batch X dim_S) Returns ------- validation_loss: tf.float32 Defines the validation loss as a scalar for the whole batch """ return self.loss_fn(expected_output, predictions)