From eee9bce7dcf931f8d802de2d20c0d362d1d507fc Mon Sep 17 00:00:00 2001
From: zhiyang7
Date: Tue, 21 Dec 2021 10:46:24 +0800
Subject: [PATCH] Remove redundant tensor wrapping; use ONNX for eval
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 douzero/dmc/dmc.py               | 21 ++++++++++-----------
 douzero/dmc/env_utils.py         | 20 ++++++++++++--------
 douzero/dmc/models.py            | 26 ++++++++++++++++----------
 douzero/dmc/utils.py             |  2 +-
 douzero/evaluation/deep_agent.py | 24 ++++++++++++++++++++++--
 5 files changed, 61 insertions(+), 32 deletions(-)

diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py
index a7d3c21..3b3ca33 100644
--- a/douzero/dmc/dmc.py
+++ b/douzero/dmc/dmc.py
@@ -132,8 +132,8 @@ def train(flags):

     def sync_onnx_model(frames):
         for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
-            if flags.enable_onnx and position:
-                model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
+            if flags.enable_onnx:
+                model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
                 onnx_params = learner_model.get_model(position)\
                     .get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
                 torch.onnx.export(
@@ -186,7 +186,7 @@ def train(flags):
     for child in parent.children():
         child.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)

-    def batch_and_learn(i, device, position, local_lock, position_lock, lock=threading.Lock()):
+    def batch_and_learn(i, position, local_lock, position_lock, lock=threading.Lock()):
         """Thread target for the learning process."""
         nonlocal frames, position_frames, stats
         while frames < flags.total_frames:
@@ -209,14 +209,13 @@ def train(flags):
         locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
     position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}

-    for device in device_iterator:
-        for i in range(flags.num_threads):
-            for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
-                thread = threading.Thread(
-                    target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,device,position,locks[device][position],position_locks[position]))
-                thread.setDaemon(True)
-                thread.start()
-                threads.append(thread)
+    for i in range(flags.num_threads):
+        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
+            thread = threading.Thread(
+                target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,position,locks[device][position],position_locks[position]))
+            thread.setDaemon(True)
+            thread.start()
+            threads.append(thread)

     def checkpoint(frames):
         if flags.disable_checkpoint:
diff --git a/douzero/dmc/env_utils.py b/douzero/dmc/env_utils.py
index dcb0585..f64fc14 100644
--- a/douzero/dmc/env_utils.py
+++ b/douzero/dmc/env_utils.py
@@ -6,17 +6,21 @@ the environment, we do it automatically.
 import numpy as np
 import torch

-def _format_observation(obs, device):
+def _format_observation(obs, device, flags):
     """
     A utility function to process observations and
     move them to CUDA.
""" position = obs['position'] - if not device == "cpu": - device = 'cuda:' + str(device) - device = torch.device(device) - x_batch = torch.from_numpy(obs['x_batch']).to(device) - z_batch = torch.from_numpy(obs['z_batch']).to(device) + if flags.enable_onnx: + x_batch = obs['x_batch'] + z_batch = obs['z_batch'] + else: + if not device == "cpu": + device = 'cuda:' + str(device) + device = torch.device(device) + x_batch = torch.from_numpy(obs['x_batch']).to(device) + z_batch = torch.from_numpy(obs['z_batch']).to(device) x_no_action = torch.from_numpy(obs['x_no_action']) z = torch.from_numpy(obs['z']) obs = {'x_batch': x_batch, @@ -35,7 +39,7 @@ class Environment: def initial(self, model, device, flags=None): obs = self.env.reset(model, device, flags=flags) - initial_position, initial_obs, x_no_action, z = _format_observation(obs, self.device) + initial_position, initial_obs, x_no_action, z = _format_observation(obs, self.device, flags) self.episode_return = torch.zeros(1, 1) initial_done = torch.ones(1, 1, dtype=torch.bool) return initial_position, initial_obs, dict( @@ -54,7 +58,7 @@ class Environment: obs = self.env.reset(model, device, flags=flags) self.episode_return = torch.zeros(1, 1) - position, obs, x_no_action, z = _format_observation(obs, self.device) + position, obs, x_no_action, z = _format_observation(obs, self.device, flags) # reward = torch.tensor(reward).view(1, 1) done = torch.tensor(done).view(1, 1) diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py index 8aeee5a..a92727e 100644 --- a/douzero/dmc/models.py +++ b/douzero/dmc/models.py @@ -421,7 +421,7 @@ class Model: def set_onnx_model(self, device='cpu'): positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] for position in positions: - model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.savedir, self.flags.xpid, position)) + model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid, position)) if device == 'cpu': self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider']) else: @@ -431,21 +431,27 @@ class Model: self.models[position].get_onnx_params(self.device) def forward(self, position, z, x, return_value=False, flags=None, debug=False): - model = self.onnx_models[position] - if model is None: + if flags.enable_onnx: + model = self.onnx_models[position] + onnx_out = model.run(None, {'z_batch': z, 'x_batch': x}) + values = torch.tensor(onnx_out[0]) + else: model = self.models[position] values = model.forward(z, x)['values'] - else: - onnx_out = model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)}) - values = torch.tensor(onnx_out[0]) if return_value: - return dict(values=values) + return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy()) else: if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: - action = torch.randint(values.shape[0], (1,))[0] + if flags.enable_onnx: + action = np.random.randint(0, values.shape[0], (1,))[0] + else: + action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy() else: - action = torch.argmax(values,dim=0)[0] - return dict(action=action) + if flags.enable_onnx: + action = np.argmax(values, axis=0)[0].cpu().detach().numpy() + else: + action = torch.argmax(values,dim=0)[0] + return dict(action = action) def share_memory(self): if self.models['landlord'] is not None: diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index 29d1d2e..ac0025b 100644 --- 
+++ b/douzero/dmc/utils.py
@@ -118,7 +118,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
                 if len(obs['legal_actions']) > 1:
                     with torch.no_grad():
                         agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
-                    _action_idx = int(agent_output['action'].cpu().detach().numpy())
+                    _action_idx = int(agent_output['action'])
                     action = obs['legal_actions'][_action_idx]
                 else:
                     action = obs['legal_actions'][0]
diff --git a/douzero/evaluation/deep_agent.py b/douzero/evaluation/deep_agent.py
index b987dd8..5fdf6c1 100644
--- a/douzero/evaluation/deep_agent.py
+++ b/douzero/evaluation/deep_agent.py
@@ -1,5 +1,8 @@
 import torch
 import numpy as np
+import os
+import onnxruntime
+from onnxruntime.datasets import get_example

 from douzero.env.env import get_obs

@@ -25,6 +28,21 @@ def _load_model(position, model_path, model_type, use_legacy):
     if torch.cuda.is_available():
         model.cuda()
     model.eval()
+    onnx_params = model.get_onnx_params(torch.device('cpu'))
+    model_path = model_path + '.onnx'
+    if not os.path.exists(model_path):
+        torch.onnx.export(
+            model,
+            onnx_params['args'],
+            model_path,
+            export_params=True,
+            opset_version=10,
+            do_constant_folding=True,
+            input_names=onnx_params['input_names'],
+            output_names=onnx_params['output_names'],
+            dynamic_axes=onnx_params['dynamic_axes']
+        )
+
     return model

 class DeepAgent:
@@ -33,6 +51,7 @@ class DeepAgent:
         self.use_legacy = True if "legacy" in model_path else False
         self.model_type = "general" if "resnet" in model_path else "old"
         self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
+        self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
         self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
                                  8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
                                  13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
@@ -46,8 +65,9 @@ class DeepAgent:
         x_batch = torch.from_numpy(obs['x_batch']).float()
         if torch.cuda.is_available():
             z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
-        y_pred = self.model.forward(z_batch, x_batch)['values']
-        y_pred = y_pred.detach().cpu().numpy()
+        # y_pred = self.model.forward(z_batch, x_batch)['values']
+        # y_pred = y_pred.detach().cpu().numpy()
+        y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0]
         best_action_index = np.argmax(y_pred, axis=0)[0]
         best_action = infoset.legal_actions[best_action_index]
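
The export/inference round trip this patch relies on (torch.onnx.export on the learner side, an onnxruntime.InferenceSession fed plain NumPy batches on the actor/eval side) can be exercised in isolation with the sketch below. It is a minimal illustration of the same pattern, not the DouZero networks: the TinyNet module, the tensor shapes, and the 'tiny.onnx' path are stand-ins.

    import numpy as np
    import torch
    import onnxruntime


    class TinyNet(torch.nn.Module):
        """Stand-in for a position model: scores a batch of candidate actions."""
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(10, 1)

        def forward(self, z, x):
            # z: (batch, 4, 2) history features, x: (batch, 2) state/action features
            feat = torch.cat([z.flatten(1), x], dim=-1)  # (batch, 10)
            return self.fc(feat)                         # one value per candidate action


    model = TinyNet().eval()
    dummy_z, dummy_x = torch.randn(3, 4, 2), torch.randn(3, 2)

    # Export with a dynamic batch axis so any number of legal actions can be scored.
    torch.onnx.export(
        model, (dummy_z, dummy_x), 'tiny.onnx',
        export_params=True, opset_version=10, do_constant_folding=True,
        input_names=['z_batch', 'x_batch'], output_names=['values'],
        dynamic_axes={'z_batch': {0: 'batch'}, 'x_batch': {0: 'batch'}, 'values': {0: 'batch'}})

    # Inference takes NumPy arrays directly; no torch tensor wrapping is needed.
    sess = onnxruntime.InferenceSession('tiny.onnx', providers=['CPUExecutionProvider'])
    values = sess.run(None, {'z_batch': np.random.randn(5, 4, 2).astype(np.float32),
                             'x_batch': np.random.randn(5, 2).astype(np.float32)})[0]
    best_action_index = int(np.argmax(values, axis=0)[0])  # same argmax pattern as act() / DeepAgent.act()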