移除多余的tensor封装,eval使用onnx
This commit is contained in:
parent
7e190c2353
commit
eee9bce7dc
|
@ -132,8 +132,8 @@ def train(flags):
|
||||||
|
|
||||||
def sync_onnx_model(frames):
|
def sync_onnx_model(frames):
|
||||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
||||||
if flags.enable_onnx and position:
|
if flags.enable_onnx:
|
||||||
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
|
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
|
||||||
onnx_params = learner_model.get_model(position)\
|
onnx_params = learner_model.get_model(position)\
|
||||||
.get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
|
.get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
|
@ -186,7 +186,7 @@ def train(flags):
|
||||||
for child in parent.children():
|
for child in parent.children():
|
||||||
child.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)
|
child.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)
|
||||||
|
|
||||||
def batch_and_learn(i, device, position, local_lock, position_lock, lock=threading.Lock()):
|
def batch_and_learn(i, position, local_lock, position_lock, lock=threading.Lock()):
|
||||||
"""Thread target for the learning process."""
|
"""Thread target for the learning process."""
|
||||||
nonlocal frames, position_frames, stats
|
nonlocal frames, position_frames, stats
|
||||||
while frames < flags.total_frames:
|
while frames < flags.total_frames:
|
||||||
|
@ -209,11 +209,10 @@ def train(flags):
|
||||||
locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||||
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||||
|
|
||||||
for device in device_iterator:
|
|
||||||
for i in range(flags.num_threads):
|
for i in range(flags.num_threads):
|
||||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,device,position,locks[device][position],position_locks[position]))
|
target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,position,locks[device][position],position_locks[position]))
|
||||||
thread.setDaemon(True)
|
thread.setDaemon(True)
|
||||||
thread.start()
|
thread.start()
|
||||||
threads.append(thread)
|
threads.append(thread)
|
||||||
|
|
|
@ -6,12 +6,16 @@ the environment, we do it automatically.
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def _format_observation(obs, device):
|
def _format_observation(obs, device, flags):
|
||||||
"""
|
"""
|
||||||
A utility function to process observations and
|
A utility function to process observations and
|
||||||
move them to CUDA.
|
move them to CUDA.
|
||||||
"""
|
"""
|
||||||
position = obs['position']
|
position = obs['position']
|
||||||
|
if flags.enable_onnx:
|
||||||
|
x_batch = obs['x_batch']
|
||||||
|
z_batch = obs['z_batch']
|
||||||
|
else:
|
||||||
if not device == "cpu":
|
if not device == "cpu":
|
||||||
device = 'cuda:' + str(device)
|
device = 'cuda:' + str(device)
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
@ -35,7 +39,7 @@ class Environment:
|
||||||
|
|
||||||
def initial(self, model, device, flags=None):
|
def initial(self, model, device, flags=None):
|
||||||
obs = self.env.reset(model, device, flags=flags)
|
obs = self.env.reset(model, device, flags=flags)
|
||||||
initial_position, initial_obs, x_no_action, z = _format_observation(obs, self.device)
|
initial_position, initial_obs, x_no_action, z = _format_observation(obs, self.device, flags)
|
||||||
self.episode_return = torch.zeros(1, 1)
|
self.episode_return = torch.zeros(1, 1)
|
||||||
initial_done = torch.ones(1, 1, dtype=torch.bool)
|
initial_done = torch.ones(1, 1, dtype=torch.bool)
|
||||||
return initial_position, initial_obs, dict(
|
return initial_position, initial_obs, dict(
|
||||||
|
@ -54,7 +58,7 @@ class Environment:
|
||||||
obs = self.env.reset(model, device, flags=flags)
|
obs = self.env.reset(model, device, flags=flags)
|
||||||
self.episode_return = torch.zeros(1, 1)
|
self.episode_return = torch.zeros(1, 1)
|
||||||
|
|
||||||
position, obs, x_no_action, z = _format_observation(obs, self.device)
|
position, obs, x_no_action, z = _format_observation(obs, self.device, flags)
|
||||||
# reward = torch.tensor(reward).view(1, 1)
|
# reward = torch.tensor(reward).view(1, 1)
|
||||||
done = torch.tensor(done).view(1, 1)
|
done = torch.tensor(done).view(1, 1)
|
||||||
|
|
||||||
|
|
|
@ -421,7 +421,7 @@ class Model:
|
||||||
def set_onnx_model(self, device='cpu'):
|
def set_onnx_model(self, device='cpu'):
|
||||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
for position in positions:
|
for position in positions:
|
||||||
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.savedir, self.flags.xpid, position))
|
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid, position))
|
||||||
if device == 'cpu':
|
if device == 'cpu':
|
||||||
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
|
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
|
||||||
else:
|
else:
|
||||||
|
@ -431,18 +431,24 @@ class Model:
|
||||||
self.models[position].get_onnx_params(self.device)
|
self.models[position].get_onnx_params(self.device)
|
||||||
|
|
||||||
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
|
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
|
||||||
|
if flags.enable_onnx:
|
||||||
model = self.onnx_models[position]
|
model = self.onnx_models[position]
|
||||||
if model is None:
|
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
|
||||||
|
values = torch.tensor(onnx_out[0])
|
||||||
|
else:
|
||||||
model = self.models[position]
|
model = self.models[position]
|
||||||
values = model.forward(z, x)['values']
|
values = model.forward(z, x)['values']
|
||||||
else:
|
|
||||||
onnx_out = model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
|
|
||||||
values = torch.tensor(onnx_out[0])
|
|
||||||
if return_value:
|
if return_value:
|
||||||
return dict(values=values)
|
return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
|
||||||
else:
|
else:
|
||||||
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
|
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
|
||||||
action = torch.randint(values.shape[0], (1,))[0]
|
if flags.enable_onnx:
|
||||||
|
action = np.random.randint(0, values.shape[0], (1,))[0]
|
||||||
|
else:
|
||||||
|
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
|
||||||
|
else:
|
||||||
|
if flags.enable_onnx:
|
||||||
|
action = np.argmax(values, axis=0)[0].cpu().detach().numpy()
|
||||||
else:
|
else:
|
||||||
action = torch.argmax(values,dim=0)[0]
|
action = torch.argmax(values,dim=0)[0]
|
||||||
return dict(action = action)
|
return dict(action = action)
|
||||||
|
|
|
@ -118,7 +118,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
|
||||||
if len(obs['legal_actions']) > 1:
|
if len(obs['legal_actions']) > 1:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
|
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
|
||||||
_action_idx = int(agent_output['action'].cpu().detach().numpy())
|
_action_idx = int(agent_output['action'])
|
||||||
action = obs['legal_actions'][_action_idx]
|
action = obs['legal_actions'][_action_idx]
|
||||||
else:
|
else:
|
||||||
action = obs['legal_actions'][0]
|
action = obs['legal_actions'][0]
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import onnxruntime
|
||||||
|
from onnxruntime.datasets import get_example
|
||||||
|
|
||||||
from douzero.env.env import get_obs
|
from douzero.env.env import get_obs
|
||||||
|
|
||||||
|
@ -25,6 +28,21 @@ def _load_model(position, model_path, model_type, use_legacy):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
model.cuda()
|
model.cuda()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
onnx_params = model.get_onnx_params(torch.device('cpu'))
|
||||||
|
model_path = model_path + '.onnx'
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
onnx_params['args'],
|
||||||
|
model_path,
|
||||||
|
export_params=True,
|
||||||
|
opset_version=10,
|
||||||
|
do_constant_folding=True,
|
||||||
|
input_names=onnx_params['input_names'],
|
||||||
|
output_names=onnx_params['output_names'],
|
||||||
|
dynamic_axes=onnx_params['dynamic_axes']
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
class DeepAgent:
|
class DeepAgent:
|
||||||
|
@ -33,6 +51,7 @@ class DeepAgent:
|
||||||
self.use_legacy = True if "legacy" in model_path else False
|
self.use_legacy = True if "legacy" in model_path else False
|
||||||
self.model_type = "general" if "resnet" in model_path else "old"
|
self.model_type = "general" if "resnet" in model_path else "old"
|
||||||
self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
|
self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
|
||||||
|
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
|
||||||
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||||
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
||||||
|
@ -46,8 +65,9 @@ class DeepAgent:
|
||||||
x_batch = torch.from_numpy(obs['x_batch']).float()
|
x_batch = torch.from_numpy(obs['x_batch']).float()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
|
z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
|
||||||
y_pred = self.model.forward(z_batch, x_batch)['values']
|
# y_pred = self.model.forward(z_batch, x_batch)['values']
|
||||||
y_pred = y_pred.detach().cpu().numpy()
|
# y_pred = y_pred.detach().cpu().numpy()
|
||||||
|
y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0]
|
||||||
|
|
||||||
best_action_index = np.argmax(y_pred, axis=0)[0]
|
best_action_index = np.argmax(y_pred, axis=0)[0]
|
||||||
best_action = infoset.legal_actions[best_action_index]
|
best_action = infoset.legal_actions[best_action_index]
|
||||||
|
|
Loading…
Reference in New Issue