Source code for blackbox_mpc.utils.recording

from gym.wrappers.monitoring.video_recorder import VideoRecorder
from blackbox_mpc.policies.model_free_base_policy import ModelFreeBasePolicy


def record_rollout(env, horizon, policy, record_file_path):
    """Run one episode with ``policy`` and record it to an ``.mp4`` video.

    This is the recording function for the runner class. It resets ``env``,
    steps it ``horizon`` times using actions from ``policy``, captures a frame
    before every step and one final frame after the last step.

    Parameters
    ----------
    env: gym environment
        The environment to roll out and record (passed to gym's
        ``VideoRecorder``).
    horizon: Int
        The task horizon / episode length, i.e. the number of ``env.step``
        calls made.
    policy: ModelBasedBasePolicy or ModelFreeBasePolicy
        The policy used to choose actions. A ``ModelFreeBasePolicy`` is
        expected to return only the action from ``act``; any other policy is
        expected to return ``(action, expected_obs, expected_reward)``.
    record_file_path: String
        The file path, without extension, to save the recorded video to;
        ``'.mp4'`` is appended.
    """
    recorder = VideoRecorder(env, record_file_path + '.mp4')
    try:
        observations = env.reset()
        for t in range(horizon):
            recorder.capture_frame()
            if isinstance(policy, ModelFreeBasePolicy):
                action_to_execute = policy.act(observations, t)
            else:
                # Model-based policies also return their predicted next
                # observation and reward; only the action is needed here.
                action_to_execute, _, _ = policy.act(observations, t)
            # NOTE(review): ``done`` is ignored, so stepping continues for the
            # full horizon even after the environment reports termination.
            observations, _, _, _ = env.step(action_to_execute)
        recorder.capture_frame()
    finally:
        # Always finalize the recorder so the video file is closed even if
        # the rollout raises part-way through.
        recorder.close()