# Source code for blackbox_mpc.utils.dynamics_learning

import tensorflow as tf
from blackbox_mpc.utils.rollouts import perform_rollouts
from blackbox_mpc.dynamics_handlers.system_dynamics_handler import \
    SystemDynamicsHandler


def learn_dynamics_from_policy(env, policy, number_of_rollouts, task_horizon,
                               dynamics_function=None,
                               system_dynamics_handler=None,
                               epochs=30, learning_rate=1e-3,
                               validation_split=0.2, batch_size=128,
                               is_normalized=True,
                               nn_optimizer=tf.keras.optimizers.Adam,
                               tf_writer=None, exploration_noise=False,
                               log_dir=None, save_model_frequency=1,
                               saved_model_dir=None, start_episode=0):
    """
    Collect rollouts with the given policy and train a dynamics model on them.

    The function (1) creates a SystemDynamicsHandler when none is supplied,
    (2) gathers trajectories by calling ``perform_rollouts`` with ``policy``,
    and (3) calls ``train`` on the handler with those trajectories.

    Parameters
    ---------
    env: parallelgymEnv
        a wrapped gym environment using
        blackbox.environment_utils.EnvironmentWrapper funcs. Its
        ``action_space`` and ``observation_space`` are read when a new
        handler has to be created.
    policy: ModelFreeBasePolicy or ModelBasedBasePolicy
        the policy used to generate the rollouts the dynamics are learned
        from.
    number_of_rollouts: Int
        Number of rollouts/ episodes to perform for each of the agents in the
        vectorized environment.
    task_horizon: Int
        The task horizon/ episode length.
    dynamics_function: DeterministicDynamicsFunctionBaseClass
        Defines the system dynamics function. Only used when
        ``system_dynamics_handler`` is None.
    system_dynamics_handler: SystemDynamicsHandler
        Handler of the state, actions and targets processing funcs. When
        None, a new handler is constructed from ``dynamics_function``,
        ``tf_writer``, ``is_normalized``, ``log_dir``,
        ``save_model_frequency`` and ``saved_model_dir``; when given, those
        arguments are not used to build anything here.
    epochs: Int
        Number of epochs to be used in training the dynamics function
        everytime train is called.
    learning_rate: float
        Learning rate to be used in training the dynamics function.
    validation_split: float32
        Defines the validation split to be used of the rollouts collected.
    batch_size: int
        Defines the batch size to be used for training the model.
    is_normalized: bool
        Defines if the dynamics function should be trained with
        normalization or not.
    nn_optimizer: tf.keras.optimizers
        Defines the optimizer to use with the neural network; forwarded
        as-is to the handler's ``train`` (the default is the
        ``tf.keras.optimizers.Adam`` class itself, not an instance).
    tf_writer: tf.summary
        Tensorflow writer to be used in logging the data; passed both to a
        newly created handler and to ``perform_rollouts``.
    exploration_noise: bool
        Defines if exploration noise should be added to the action to be
        executed.
    log_dir: string
        Defines the log directory to save the normalization statistics in.
    save_model_frequency: Int
        Defines how often the model should be saved (defined relative to the
        number of refining iters)
    saved_model_dir: string
        Defines the saved model directory where the model is saved in, in
        case of loading the model.
    start_episode: Int
        the episode index for tensorflow logging purposes

    Returns
    -------
    system_dynamics_handler: SystemDynamicsHandler
        The handler (either the one passed in or the newly created one)
        after ``train`` has been called on it.
    """
    handler = system_dynamics_handler
    if handler is None:
        # No handler supplied: build one around the given dynamics function.
        handler = SystemDynamicsHandler(
            env_action_space=env.action_space,
            env_observation_space=env.observation_space,
            true_model=False,
            dynamics_function=dynamics_function,
            tf_writer=tf_writer,
            is_normalized=is_normalized,
            log_dir=log_dir,
            save_model_frequency=save_model_frequency,
            saved_model_dir=saved_model_dir)

    # Gather the training data by running the policy in the environment.
    observations, actions, rewards = perform_rollouts(
        env, number_of_rollouts, task_horizon, policy,
        exploration_noise=exploration_noise,
        tf_writer=tf_writer,
        start_episode=start_episode)

    # Fit the dynamics model on the collected trajectories.
    handler.train(observations, actions, rewards,
                  validation_split=validation_split,
                  batch_size=batch_size,
                  learning_rate=learning_rate,
                  epochs=epochs,
                  nn_optimizer=nn_optimizer)
    return handler