使用onnx-gpu进行infer逻辑,未完成
This commit is contained in:
parent
601dafb008
commit
3369b491e2
|
@ -1 +1 @@
|
||||||
python train.py --load_model --batch_size 8 --learning_rate 0.0003
|
python train.py --load_model --batch_size 8 --learning_rate 0.0003 --enable_onnx
|
|
@ -68,6 +68,7 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock):
|
||||||
nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
|
nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
if not flags.enable_onnx:
|
||||||
for actor_model in actor_models.values():
|
for actor_model in actor_models.values():
|
||||||
actor_model.get_model(position).load_state_dict(model.state_dict())
|
actor_model.get_model(position).load_state_dict(model.state_dict())
|
||||||
return stats
|
return stats
|
||||||
|
@ -103,9 +104,9 @@ def train(flags):
|
||||||
models = {}
|
models = {}
|
||||||
for device in device_iterator:
|
for device in device_iterator:
|
||||||
if flags.old_model:
|
if flags.old_model:
|
||||||
model = OldModel(device="cpu")
|
model = OldModel(device="cpu", flags = flags)
|
||||||
else:
|
else:
|
||||||
model = Model(device="cpu")
|
model = Model(device="cpu", flags = flags)
|
||||||
model.share_memory()
|
model.share_memory()
|
||||||
model.eval()
|
model.eval()
|
||||||
models[device] = model
|
models[device] = model
|
||||||
|
@ -149,6 +150,7 @@ def train(flags):
|
||||||
for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_down']
|
for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_down']
|
||||||
learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
|
learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
|
||||||
optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
|
optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
|
||||||
|
if not flags.enable_onnx:
|
||||||
for device in device_iterator:
|
for device in device_iterator:
|
||||||
models[device].get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
|
models[device].get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
|
||||||
stats = checkpoint_states["stats"]
|
stats = checkpoint_states["stats"]
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
This file includes the torch models. We wrap the three
|
This file includes the torch models. We wrap the three
|
||||||
models into one class for convenience.
|
models into one class for convenience.
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -331,8 +332,8 @@ class GeneralModel(nn.Module):
|
||||||
def get_onnx_params(self):
|
def get_onnx_params(self):
|
||||||
return {
|
return {
|
||||||
'args': (
|
'args': (
|
||||||
torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32),
|
torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32, device='cuda:0'),
|
||||||
torch.tensor(np.zeros((1, 80)), dtype=torch.float32)
|
torch.tensor(np.zeros((1, 80)), dtype=torch.float32, device='cuda:0')
|
||||||
),
|
),
|
||||||
'input_names': ['z_batch','x_batch'],
|
'input_names': ['z_batch','x_batch'],
|
||||||
'output_names': ['values'],
|
'output_names': ['values'],
|
||||||
|
@ -478,7 +479,7 @@ class OldModel:
|
||||||
The wrapper for the three models. We also wrap several
|
The wrapper for the three models. We also wrap several
|
||||||
interfaces such as share_memory, eval, etc.
|
interfaces such as share_memory, eval, etc.
|
||||||
"""
|
"""
|
||||||
def __init__(self, device=0):
|
def __init__(self, device=0, flags=None):
|
||||||
self.models = {}
|
self.models = {}
|
||||||
if not device == "cpu":
|
if not device == "cpu":
|
||||||
device = 'cuda:' + str(device)
|
device = 'cuda:' + str(device)
|
||||||
|
@ -519,6 +520,7 @@ class OldModel:
|
||||||
return dict(action=action)
|
return dict(action=action)
|
||||||
|
|
||||||
def share_memory(self):
|
def share_memory(self):
|
||||||
|
if self.models['landlord'] is not None:
|
||||||
self.models['landlord'].share_memory()
|
self.models['landlord'].share_memory()
|
||||||
self.models['landlord_up'].share_memory()
|
self.models['landlord_up'].share_memory()
|
||||||
self.models['landlord_front'].share_memory()
|
self.models['landlord_front'].share_memory()
|
||||||
|
@ -526,6 +528,7 @@ class OldModel:
|
||||||
self.models['bidding'].share_memory()
|
self.models['bidding'].share_memory()
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
|
if self.models['landlord'] is not None:
|
||||||
self.models['landlord'].eval()
|
self.models['landlord'].eval()
|
||||||
self.models['landlord_up'].eval()
|
self.models['landlord_up'].eval()
|
||||||
self.models['landlord_front'].eval()
|
self.models['landlord_front'].eval()
|
||||||
|
@ -547,15 +550,21 @@ class Model:
|
||||||
The wrapper for the three models. We also wrap several
|
The wrapper for the three models. We also wrap several
|
||||||
interfaces such as share_memory, eval, etc.
|
interfaces such as share_memory, eval, etc.
|
||||||
"""
|
"""
|
||||||
def __init__(self, device=0):
|
def __init__(self, device=0, flags=None):
|
||||||
self.models = {}
|
self.models = {}
|
||||||
|
self.onnx_models = {}
|
||||||
|
self.flags = flags
|
||||||
if not device == "cpu":
|
if not device == "cpu":
|
||||||
device = 'cuda:' + str(device)
|
device = 'cuda:' + str(device)
|
||||||
# model = GeneralModel().to(torch.device(device))
|
# model = GeneralModel().to(torch.device(device))
|
||||||
self.models['landlord'] = GeneralModel().to(torch.device(device))
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
self.models['landlord_up'] = GeneralModel().to(torch.device(device))
|
if flags is not None and flags.enable_onnx:
|
||||||
self.models['landlord_front'] = GeneralModel().to(torch.device(device))
|
self.models['bidding'] = BidModel().to(torch.device(device))
|
||||||
self.models['landlord_down'] = GeneralModel().to(torch.device(device))
|
for position in positions:
|
||||||
|
self.models[position] = None
|
||||||
|
else:
|
||||||
|
for position in positions:
|
||||||
|
self.models[position] = GeneralModel().to(torch.device(device))
|
||||||
self.models['bidding'] = BidModel().to(torch.device(device))
|
self.models['bidding'] = BidModel().to(torch.device(device))
|
||||||
self.onnx_models = {
|
self.onnx_models = {
|
||||||
'landlord': None,
|
'landlord': None,
|
||||||
|
@ -564,15 +573,20 @@ class Model:
|
||||||
'landlord_down': None,
|
'landlord_down': None,
|
||||||
'bidding': None
|
'bidding': None
|
||||||
}
|
}
|
||||||
self.models['bidding'] = BidModel().to(torch.device(device))
|
|
||||||
|
|
||||||
def set_onnx_model(self, position, model_path):
|
def set_onnx_model(self):
|
||||||
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path))
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
|
for position in positions:
|
||||||
|
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.savedir, self.flags.xpid, position))
|
||||||
|
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||||
|
self.onnx_models['bidding'] = None
|
||||||
|
|
||||||
def get_onnx_params(self, position):
|
def get_onnx_params(self, position):
|
||||||
self.models[position].get_onnx_params()
|
self.models[position].get_onnx_params()
|
||||||
|
|
||||||
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
|
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
|
||||||
|
if self.flags.enable_onnx and len(self.onnx_models) == 0:
|
||||||
|
self.set_onnx_model()
|
||||||
model = self.onnx_models[position]
|
model = self.onnx_models[position]
|
||||||
if model is None:
|
if model is None:
|
||||||
model = self.models[position]
|
model = self.models[position]
|
||||||
|
@ -590,6 +604,7 @@ class Model:
|
||||||
return dict(action=action)
|
return dict(action=action)
|
||||||
|
|
||||||
def share_memory(self):
|
def share_memory(self):
|
||||||
|
if self.models['landlord'] is not None:
|
||||||
self.models['landlord'].share_memory()
|
self.models['landlord'].share_memory()
|
||||||
self.models['landlord_up'].share_memory()
|
self.models['landlord_up'].share_memory()
|
||||||
self.models['landlord_front'].share_memory()
|
self.models['landlord_front'].share_memory()
|
||||||
|
@ -597,6 +612,7 @@ class Model:
|
||||||
self.models['bidding'].share_memory()
|
self.models['bidding'].share_memory()
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
|
if self.models['landlord'] is not None:
|
||||||
self.models['landlord'].eval()
|
self.models['landlord'].eval()
|
||||||
self.models['landlord_up'].eval()
|
self.models['landlord_up'].eval()
|
||||||
self.models['landlord_front'].eval()
|
self.models['landlord_front'].eval()
|
||||||
|
|
|
@ -83,8 +83,11 @@ def create_optimizers(flags, learner_model):
|
||||||
|
|
||||||
def act(i, device, batch_queues, model, flags, onnx_frame):
|
def act(i, device, batch_queues, model, flags, onnx_frame):
|
||||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']
|
||||||
|
if not flags.enable_onnx:
|
||||||
for pos in positions:
|
for pos in positions:
|
||||||
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
|
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
|
||||||
|
else:
|
||||||
|
model.models['bidding'].to(torch.device(device if device == "cpu" else ("cuda:" + str(device))))
|
||||||
try:
|
try:
|
||||||
T = flags.unroll_length
|
T = flags.unroll_length
|
||||||
log.info('Device %s Actor %i started.', str(device), i)
|
log.info('Device %s Actor %i started.', str(device), i)
|
||||||
|
@ -117,9 +120,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
|
||||||
last_onnx_frame = onnx_frame.value
|
last_onnx_frame = onnx_frame.value
|
||||||
for p in positions:
|
for p in positions:
|
||||||
if p != 'bidding':
|
if p != 'bidding':
|
||||||
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, p)
|
model.set_onnx_model()
|
||||||
if os.path.exists(model_path):
|
|
||||||
model.set_onnx_model(p, os.path.abspath(model_path))
|
|
||||||
|
|
||||||
for bid_obs in bid_obs_buffer:
|
for bid_obs in bid_obs_buffer:
|
||||||
obs_z_buf["bidding"].append(bid_obs['z_batch'])
|
obs_z_buf["bidding"].append(bid_obs['z_batch'])
|
||||||
|
|
|
@ -4,4 +4,4 @@ gitdb2
|
||||||
rlcard
|
rlcard
|
||||||
psutil
|
psutil
|
||||||
onnx
|
onnx
|
||||||
onnxruntime
|
onnxruntime-gpu
|
||||||
|
|
Loading…
Reference in New Issue