2021-09-07 16:38:34 +08:00
|
|
|
"""
|
|
|
|
Here, we wrap the original environment to make it easier
|
|
|
|
to use. When a game is finished, instead of manually resetting
|
|
|
|
the environment, we do it automatically.
|
|
|
|
"""
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
2021-12-21 10:46:24 +08:00
|
|
|
def _format_observation(obs, device, flags):
    """
    Split a raw observation into its position, the model inputs and
    the two per-step feature tensors.

    Args:
        obs: mapping with the keys 'position', 'x_batch', 'z_batch',
            'x_no_action', 'z' and 'legal_actions'. 'x_no_action' and
            'z' are always passed to ``torch.from_numpy``; the two
            batches are converted only when ONNX mode is off.
        device: the string "cpu", or a value that is appended to
            'cuda:' to build the target device (presumably a GPU
            index -- confirm against callers). Unused in ONNX mode.
        flags: options object; only its ``enable_onnx`` attribute is
            read. ``None`` (the default that ``Environment.initial``
            and ``Environment.step`` forward) or an object without
            that attribute is treated as ONNX mode off.

    Returns:
        A tuple ``(position, obs, x_no_action, z)`` where ``obs`` is a
        new dict holding 'x_batch', 'z_batch' and 'legal_actions'.
    """
    position = obs['position']

    # `flags` may legitimately be None, so the attribute is looked up
    # defensively instead of dereferencing it directly.
    if getattr(flags, 'enable_onnx', False):
        # ONNX mode: the batches are handed over exactly as received.
        x_batch = obs['x_batch']
        z_batch = obs['z_batch']
    else:
        if device != "cpu":
            device = 'cuda:' + str(device)
        device = torch.device(device)
        x_batch = torch.from_numpy(obs['x_batch']).to(device)
        z_batch = torch.from_numpy(obs['z_batch']).to(device)

    # These two are converted but never moved with `.to(device)`.
    x_no_action = torch.from_numpy(obs['x_no_action'])
    z = torch.from_numpy(obs['z'])

    obs = {
        'x_batch': x_batch,
        'z_batch': z_batch,
        'legal_actions': obs['legal_actions'],
    }
    return position, obs, x_no_action, z
|
|
|
|
|
|
|
|
class Environment:
    """
    Wrapper around a game environment that resets the wrapped
    environment by itself as soon as a step reports the game as done.
    """

    def __init__(self, env, device):
        """ Initialize this environment wrapper

        Args:
            env: the wrapped environment; its ``reset``, ``step`` and
                ``close`` methods are called by this class.
            device: device value handed to ``_format_observation``
                whenever an observation is formatted.
        """
        self.env, self.device, self.episode_return = env, device, None

    def initial(self, model, device, flags=None):
        """
        Reset the wrapped environment and format its first observation.

        ``model``, ``device`` and ``flags`` are forwarded to
        ``env.reset``; the observation itself is formatted with
        ``self.device``.

        Returns:
            ``(position, obs, info)`` where ``info`` carries 'done'
            (a 1x1 boolean tensor of ones), 'episode_return' (a 1x1
            zero tensor), 'obs_x_no_action' and 'obs_z'.
        """
        raw_obs = self.env.reset(model, device, flags=flags)
        position, model_inputs, x_no_action, z = _format_observation(
            raw_obs, self.device, flags)

        self.episode_return = torch.zeros(1, 1)
        info = {
            'done': torch.ones(1, 1, dtype=torch.bool),
            'episode_return': self.episode_return,
            'obs_x_no_action': x_no_action,
            'obs_z': z,
        }
        return position, model_inputs, info

    def step(self, action, model, device, flags=None):
        """
        Apply ``action`` to the wrapped environment.

        When the step reports the game as done, the environment is
        reset right away (forwarding ``model``, ``device`` and
        ``flags``), so the returned observation then comes from the
        fresh game while 'done' and 'episode_return' still describe
        the step that just ended.

        Returns:
            ``(position, obs, info)`` where ``info`` carries 'done'
            (reshaped to 1x1), 'episode_return' (the reward of this
            step), 'obs_x_no_action' and 'obs_z'.
        """
        raw_obs, reward, finished, _ = self.env.step(action)

        # The stored return is simply the latest reward; it is replaced
        # by a zero tensor below once the game has ended.
        self.episode_return = reward
        reported_return = reward

        if finished:
            raw_obs = self.env.reset(model, device, flags=flags)
            self.episode_return = torch.zeros(1, 1)

        position, model_inputs, x_no_action, z = _format_observation(
            raw_obs, self.device, flags)

        info = {
            'done': torch.tensor(finished).view(1, 1),
            'episode_return': reported_return,
            'obs_x_no_action': x_no_action,
            'obs_z': z,
        }
        return position, model_inputs, info

    def close(self):
        """Close the wrapped environment."""
        self.env.close()
|