"""
Wrap the original environment to make it easier to use.

When a game is finished, instead of manually resetting the environment,
the wrapper resets it automatically.
"""
import numpy as np
import torch


def _format_observation(obs):
    """
    Split a raw observation dict into the pieces the caller consumes.

    No device transfer happens here: ``x_no_action`` and ``z`` are converted
    with ``torch.from_numpy``, so they are CPU tensors that share memory with
    the source NumPy arrays. ``x_batch`` and ``z_batch`` are passed through
    unchanged (not converted to tensors).

    Args:
        obs: Raw observation mapping. Must contain the keys ``'position'``,
            ``'x_batch'``, ``'z_batch'``, ``'legal_actions'``,
            ``'x_no_action'`` and ``'z'``; the last two must be NumPy arrays
            (a requirement of ``torch.from_numpy``).

    Returns:
        A tuple ``(position, formatted_obs, x_no_action, z)`` where
        ``formatted_obs`` is a new dict holding only ``'x_batch'``,
        ``'z_batch'`` and ``'legal_actions'``.
    """
    position = obs['position']
    x_batch = obs['x_batch']
    z_batch = obs['z_batch']
    x_no_action = torch.from_numpy(obs['x_no_action'])
    z = torch.from_numpy(obs['z'])
    formatted_obs = {
        'x_batch': x_batch,
        'z_batch': z_batch,
        'legal_actions': obs['legal_actions'],
    }
    return position, formatted_obs, x_no_action, z


class Environment:
    """
    Auto-resetting wrapper around a game environment.

    The wrapped ``env`` is used through ``reset(flags=...)``, ``step(action)``
    (which must return a 4-tuple ``(obs, reward, done, info)``) and
    ``close()``.
    """

    def __init__(self, env, device):
        """
        Initialize this environment wrapper.

        Args:
            env: The environment to wrap.
            device: Stored as ``self.device``. NOTE(review): nothing in this
                class reads it; all tensors produced here stay on the CPU.
        """
        self.env = env
        self.device = device
        # Set by initial() and step(); None until one of them has run.
        self.episode_return = None

    def initial(self, flags=None):
        """
        Reset the wrapped environment and return the first observation.

        Args:
            flags: Forwarded unchanged to ``env.reset(flags=...)``.

        Returns:
            ``(position, obs, info)`` where ``info`` has the keys ``done``
            (a ``(1, 1)`` bool tensor set to True), ``episode_return`` (a
            ``(1, 1)`` float zero tensor), ``obs_x_no_action`` and ``obs_z``.
        """
        obs = self.env.reset(flags=flags)
        initial_position, initial_obs, x_no_action, z = _format_observation(obs)
        self.episode_return = torch.zeros(1, 1)
        # Presumably True so consumers treat this as an episode boundary
        # (TorchBeast-style convention) -- confirm against the caller.
        initial_done = torch.ones(1, 1, dtype=torch.bool)
        return initial_position, initial_obs, dict(
            done=initial_done,
            episode_return=self.episode_return,
            obs_x_no_action=x_no_action,
            obs_z=z,
        )

    def step(self, action, flags=None):
        """
        Advance the environment by one action, auto-resetting when done.

        Args:
            action: Forwarded unchanged to ``env.step``.
            flags: Forwarded to ``env.reset(flags=...)``, only when this step
                ends the game.

        Returns:
            ``(position, obs, info)``. When the game ended, ``position`` and
            ``obs`` come from the freshly reset game, while ``info['done']``
            (a ``(1, 1)`` bool tensor) and ``info['episode_return']`` still
            describe the step that just finished. ``episode_return`` is the
            raw ``reward`` object returned by ``env.step``; it is not
            converted to a tensor (unlike the value returned by ``initial``).
        """
        obs, reward, done, _ = self.env.step(action)
        # NOTE(review): the reward is assigned, not accumulated. This equals
        # the episode total only if the env reports a non-zero reward solely
        # on the final step of a game -- confirm against the wrapped env.
        self.episode_return = reward
        episode_return = self.episode_return

        if done:
            obs = self.env.reset(flags=flags)
            self.episode_return = torch.zeros(1, 1)

        position, obs, x_no_action, z = _format_observation(obs)
        done = torch.tensor(done).view(1, 1)

        return position, obs, dict(
            done=done,
            episode_return=episode_return,
            obs_x_no_action=x_no_action,
            obs_z=z,
        )

    def close(self):
        """Close the wrapped environment."""
        self.env.close()