使用onnx-gpu进行infer逻辑,未完成

This commit is contained in:
ZaneYork 2021-12-15 22:09:18 +08:00
parent 601dafb008
commit 3369b491e2
5 changed files with 67 additions and 48 deletions

View File

@ -1 +1 @@
python train.py --load_model --batch_size 8 --learning_rate 0.0003
python train.py --load_model --batch_size 8 --learning_rate 0.0003 --enable_onnx

View File

@ -68,6 +68,7 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock):
nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
optimizer.step()
if not flags.enable_onnx:
for actor_model in actor_models.values():
actor_model.get_model(position).load_state_dict(model.state_dict())
return stats
@ -103,9 +104,9 @@ def train(flags):
models = {}
for device in device_iterator:
if flags.old_model:
model = OldModel(device="cpu")
model = OldModel(device="cpu", flags = flags)
else:
model = Model(device="cpu")
model = Model(device="cpu", flags = flags)
model.share_memory()
model.eval()
models[device] = model
@ -149,6 +150,7 @@ def train(flags):
for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_down']
learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
if not flags.enable_onnx:
for device in device_iterator:
models[device].get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
stats = checkpoint_states["stats"]

View File

@ -2,6 +2,7 @@
This file includes the torch models. We wrap the three
models into one class for convenience.
"""
import os
import numpy as np
@ -331,8 +332,8 @@ class GeneralModel(nn.Module):
def get_onnx_params(self):
return {
'args': (
torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32),
torch.tensor(np.zeros((1, 80)), dtype=torch.float32)
torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32, device='cuda:0'),
torch.tensor(np.zeros((1, 80)), dtype=torch.float32, device='cuda:0')
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
@ -478,7 +479,7 @@ class OldModel:
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0):
def __init__(self, device=0, flags=None):
self.models = {}
if not device == "cpu":
device = 'cuda:' + str(device)
@ -519,6 +520,7 @@ class OldModel:
return dict(action=action)
def share_memory(self):
if self.models['landlord'] is not None:
self.models['landlord'].share_memory()
self.models['landlord_up'].share_memory()
self.models['landlord_front'].share_memory()
@ -526,6 +528,7 @@ class OldModel:
self.models['bidding'].share_memory()
def eval(self):
if self.models['landlord'] is not None:
self.models['landlord'].eval()
self.models['landlord_up'].eval()
self.models['landlord_front'].eval()
@ -547,15 +550,21 @@ class Model:
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0):
def __init__(self, device=0, flags=None):
self.models = {}
self.onnx_models = {}
self.flags = flags
if not device == "cpu":
device = 'cuda:' + str(device)
# model = GeneralModel().to(torch.device(device))
self.models['landlord'] = GeneralModel().to(torch.device(device))
self.models['landlord_up'] = GeneralModel().to(torch.device(device))
self.models['landlord_front'] = GeneralModel().to(torch.device(device))
self.models['landlord_down'] = GeneralModel().to(torch.device(device))
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
if flags is not None and flags.enable_onnx:
self.models['bidding'] = BidModel().to(torch.device(device))
for position in positions:
self.models[position] = None
else:
for position in positions:
self.models[position] = GeneralModel().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
self.onnx_models = {
'landlord': None,
@ -564,15 +573,20 @@ class Model:
'landlord_down': None,
'bidding': None
}
self.models['bidding'] = BidModel().to(torch.device(device))
def set_onnx_model(self, position, model_path):
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path))
def set_onnx_model(self):
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
for position in positions:
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.savedir, self.flags.xpid, position))
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.onnx_models['bidding'] = None
def get_onnx_params(self, position):
self.models[position].get_onnx_params()
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
if self.flags.enable_onnx and len(self.onnx_models) == 0:
self.set_onnx_model()
model = self.onnx_models[position]
if model is None:
model = self.models[position]
@ -590,6 +604,7 @@ class Model:
return dict(action=action)
def share_memory(self):
if self.models['landlord'] is not None:
self.models['landlord'].share_memory()
self.models['landlord_up'].share_memory()
self.models['landlord_front'].share_memory()
@ -597,6 +612,7 @@ class Model:
self.models['bidding'].share_memory()
def eval(self):
if self.models['landlord'] is not None:
self.models['landlord'].eval()
self.models['landlord_up'].eval()
self.models['landlord_front'].eval()

View File

@ -83,8 +83,11 @@ def create_optimizers(flags, learner_model):
def act(i, device, batch_queues, model, flags, onnx_frame):
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']
if not flags.enable_onnx:
for pos in positions:
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
else:
model.models['bidding'].to(torch.device(device if device == "cpu" else ("cuda:" + str(device))))
try:
T = flags.unroll_length
log.info('Device %s Actor %i started.', str(device), i)
@ -117,9 +120,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
last_onnx_frame = onnx_frame.value
for p in positions:
if p != 'bidding':
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, p)
if os.path.exists(model_path):
model.set_onnx_model(p, os.path.abspath(model_path))
model.set_onnx_model()
for bid_obs in bid_obs_buffer:
obs_z_buf["bidding"].append(bid_obs['z_batch'])

View File

@ -4,4 +4,4 @@ gitdb2
rlcard
psutil
onnx
onnxruntime
onnxruntime-gpu