使用onnx进行infer逻辑,未完成

This commit is contained in:
ZaneYork 2021-12-14 22:55:03 +08:00
parent f054fed61c
commit 0cb3d040cb
6 changed files with 133 additions and 111 deletions

View File

@ -19,6 +19,7 @@ import psutil
import shutil
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']}
onnx_frame = mp.Value('d', -1)
def compute_loss(logits, targets):
loss = ((logits.squeeze(-1) - targets)**2).mean()
@ -79,6 +80,7 @@ def train(flags):
Then it will start subprocesses as actors. Then, it will call
learning function with multiple threads.
"""
global onnx_frame
if not flags.actor_device_cpu or flags.training_device != 'cpu':
if not torch.cuda.is_available():
raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`")
@ -171,7 +173,7 @@ def train(flags):
for i in range(num_actors):
actor = mp.Process(
target=act,
args=(i, device, batch_queues, models[device], flags))
args=(i, device, batch_queues, models[device], flags, onnx_frame))
actor.daemon = True
actor.start()
actor_processes.append(actor)
@ -214,6 +216,7 @@ def train(flags):
threads.append(thread)
def checkpoint(frames):
global onnx_frame
if flags.disable_checkpoint:
return
log.info('Saving checkpoint to %s', checkpointpath)
@ -228,25 +231,32 @@ def train(flags):
}, checkpointpath + '.new')
# Save the weights for evaluation purpose
dummy_input = (torch.tensor(np.zeros([80, 40, 108]), dtype=torch.float32),
torch.tensor(np.zeros((80, 80)), dtype=torch.int8),
{
'return_value': False,
'flags': {'exp_epsilon':0.001}
},
)
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
dummy_input = (
torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32),
torch.tensor(np.zeros((1, 80)), dtype=torch.float32)
)
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']:
model_weights_dir = os.path.expandvars(os.path.expanduser(
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt')))
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
if position != 'bidding':
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
torch.onnx.export(
learner_model.get_model(position),
dummy_input,
'%s/model_%s.onnx' % (flags.savedir, position),
input_names=['z_batch','x_batch','flags'],
output_names=['action', 'max_value']
model_path,
input_names=['z_batch','x_batch'],
output_names=['values'],
dynamic_axes={
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
}
}
)
onnx_frame = frames
shutil.move(checkpointpath + '.new', checkpointpath)

View File

@ -6,9 +6,14 @@ models into one class for convenience.
import numpy as np
import torch
import onnxruntime
from onnxruntime.datasets import get_example
from torch import nn
import torch.nn.functional as F
def to_numpy(tensor):
    """Return the contents of a torch tensor as a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion. ``detach()`` is applied unconditionally: it is required
    for tensors with ``requires_grad=True`` and leaves the data of any other
    tensor untouched, so a separate branch per case is unnecessary.

    Args:
        tensor: The torch tensor to convert.

    Returns:
        The NumPy array produced by ``tensor.detach().cpu().numpy()``.
    """
    return tensor.detach().cpu().numpy()
class LandlordLstmModel(nn.Module):
def __init__(self):
super().__init__()
@ -20,7 +25,7 @@ class LandlordLstmModel(nn.Module):
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def forward(self, z, x, return_value=False, flags=None):
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
@ -35,14 +40,7 @@ class LandlordLstmModel(nn.Module):
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
return dict(values=x)
class FarmerLstmModel(nn.Module):
def __init__(self):
@ -55,7 +53,7 @@ class FarmerLstmModel(nn.Module):
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def forward(self, z, x, return_value=False, flags=None):
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
@ -70,14 +68,7 @@ class FarmerLstmModel(nn.Module):
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
return dict(values=x)
class LandlordLstmModelLegacy(nn.Module):
def __init__(self):
@ -90,7 +81,7 @@ class LandlordLstmModelLegacy(nn.Module):
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def forward(self, z, x, return_value=False, flags=None):
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
@ -105,14 +96,7 @@ class LandlordLstmModelLegacy(nn.Module):
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
return dict(values=x)
class FarmerLstmModelLegacy(nn.Module):
def __init__(self):
@ -125,7 +109,7 @@ class FarmerLstmModelLegacy(nn.Module):
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def forward(self, z, x, return_value=False, flags=None):
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
@ -140,14 +124,7 @@ class FarmerLstmModelLegacy(nn.Module):
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
return dict(values=x)
class GeneralModel1(nn.Module):
def __init__(self):
@ -185,7 +162,7 @@ class GeneralModel1(nn.Module):
self.dense5 = nn.Linear(512, 512)
self.dense6 = nn.Linear(512, 1)
def forward(self, z, x, return_value=False, flags=None, debug=False):
def forward(self, z, x):
z = z.unsqueeze(1)
z = self.conv_z_1(z)
z = z.squeeze(-1)
@ -209,14 +186,7 @@ class GeneralModel1(nn.Module):
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action, max_value=torch.max(x))
return dict(values=x)
# Residual block used by ResNet-18 and ResNet-34: two 3x3 convolutions
@ -278,7 +248,7 @@ class GeneralModelLegacy(nn.Module):
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, z, x, return_value=False, flags=None, debug=False):
def forward(self, z, x):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
out = self.layer2(out)
@ -291,14 +261,7 @@ class GeneralModelLegacy(nn.Module):
out = F.leaky_relu_(self.linear3(out))
out = F.leaky_relu_(self.linear4(out))
out = F.leaky_relu_(self.linear5(out))
if return_value:
return dict(values=out)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(out.shape[0], (1,))[0]
else:
action = torch.argmax(out,dim=0)[0]
return dict(action=action, max_value=torch.max(out))
return dict(values=out)
class GeneralModel(nn.Module):
def __init__(self):
@ -329,7 +292,7 @@ class GeneralModel(nn.Module):
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, z, x, return_value=False, flags=None, debug=False):
def forward(self, z, x):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
out = self.layer2(out)
@ -342,17 +305,7 @@ class GeneralModel(nn.Module):
out = F.leaky_relu_(self.linear3(out))
out = F.leaky_relu_(self.linear4(out))
out = F.leaky_relu_(self.linear5(out))
if return_value:
return dict(values=out)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(out.shape[0], (1,))[0]
else:
action = torch.argmax(out,dim=0)[0]
return dict(action=action, max_value=torch.max(out))
return dict(values=out)
class BidModel(nn.Module):
@ -366,7 +319,7 @@ class BidModel(nn.Module):
self.dense5 = nn.Linear(512, 512)
self.dense6 = nn.Linear(512, 1)
def forward(self, z, x, return_value=False, flags=None, debug=False):
def forward(self, z, x):
x = self.dense1(x)
x = F.leaky_relu(x)
# x = F.relu(x)
@ -383,14 +336,7 @@ class BidModel(nn.Module):
# x = F.relu(x)
x = F.leaky_relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action, max_value=torch.max(x))
return dict(values=x)
# Model dict is only used in evaluation but not training
@ -438,9 +384,17 @@ class General_Model:
self.models['landlord_down'] = GeneralModel1().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
def forward(self, position, z, x, training=False, flags=None, debug=False):
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
model = self.models[position]
return model.forward(z, x, training, flags, debug)
values = model.forward(z, x)['values']
if return_value:
return dict(values=values)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(values.shape[0], (1,))[0]
else:
action = torch.argmax(values,dim=0)[0]
return dict(action=action, max_value=torch.max(values))
def share_memory(self):
self.models['landlord'].share_memory()
@ -480,9 +434,17 @@ class OldModel:
self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
def forward(self, position, z, x, training=False, flags=None):
def forward(self, position, z, x, return_value=False, flags=None):
model = self.models[position]
return model.forward(z, x, training, flags)
values = model.forward(z, x)['values']
if return_value:
return dict(values=values)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(values.shape[0], (1,))[0]
else:
action = torch.argmax(values,dim=0)[0]
return dict(action=action, max_value=torch.max(values))
def share_memory(self):
self.models['landlord'].share_memory()
@ -523,10 +485,35 @@ class Model:
self.models['landlord_front'] = GeneralModel().to(torch.device(device))
self.models['landlord_down'] = GeneralModel().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
self.onnx_models = {
'landlord': None,
'landlord_up': None,
'landlord_front': None,
'landlord_down': None,
'bidding': None
}
self.models['bidding'] = BidModel().to(torch.device(device))
def forward(self, position, z, x, training=False, flags=None, debug=False):
model = self.models[position]
return model.forward(z, x, training, flags, debug)
def set_onnx_model(self, position, model_path):
    """Register the ONNX file to use for ``position``.

    The path is resolved through ``get_example`` and stored in
    ``self.onnx_models``; anything that helper raises for a path it cannot
    resolve propagates to the caller.

    Args:
        position: Key into ``self.onnx_models``.
        model_path: Path of the exported ``.onnx`` file.
    """
    self.onnx_models[position] = get_example(model_path)
    # The file behind this position may have been re-exported, so drop the
    # session built from the previous registration; forward() rebuilds it
    # lazily on the next call.
    cached = getattr(self, '_onnx_sessions', None)
    if cached is not None:
        cached.pop(position, None)


def forward(self, position, z, x, return_value=False, flags=None, debug=False):
    """Score the candidate rows in ``z``/``x`` and optionally pick an action.

    When an ONNX file has been registered for ``position`` the scores come
    from an ONNX Runtime session, otherwise from the PyTorch model in
    ``self.models``. The session is created once per registered file and
    reused, instead of being rebuilt on every call.

    Args:
        position: Key into ``self.models`` / ``self.onnx_models``.
        z: Tensor passed to the model as its first input (``z_batch``).
        x: Tensor passed to the model as its second input (``x_batch``).
        return_value: If true, return the raw scores instead of an action.
        flags: Optional object with an ``exp_epsilon`` attribute; when it is
            positive, a random row is chosen with that probability.
        debug: Unused; kept for signature compatibility.

    Returns:
        ``dict(values=...)`` when ``return_value`` is true, otherwise
        ``dict(action=..., max_value=...)`` where ``action`` is a row index
        into the scores and ``max_value`` is the largest score.
    """
    model_path = self.onnx_models[position]
    if model_path is None:
        values = self.models[position].forward(z, x)['values']
    else:
        # Lazily create the per-instance cache so this works even when
        # __init__ did not set it up.
        sessions = getattr(self, '_onnx_sessions', None)
        if sessions is None:
            sessions = self._onnx_sessions = {}
        sess = sessions.get(position)
        if sess is None:
            sess = onnxruntime.InferenceSession(model_path)
            sessions[position] = sess
        onnx_out = sess.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
        # NOTE(review): the result is created on the CPU regardless of the
        # device of z/x — confirm callers do not rely on the input device.
        values = torch.tensor(onnx_out[0])

    if return_value:
        return dict(values=values)

    explore = (
        flags is not None
        and flags.exp_epsilon > 0
        and np.random.rand() < flags.exp_epsilon
    )
    if explore:
        # Epsilon exploration: uniform random row index.
        action = torch.randint(values.shape[0], (1,))[0]
    else:
        action = torch.argmax(values, dim=0)[0]
    return dict(action=action, max_value=torch.max(values))
def share_memory(self):
self.models['landlord'].share_memory()

View File

@ -81,7 +81,7 @@ def create_optimizers(flags, learner_model):
return optimizers
def act(i, device, batch_queues, model, flags):
def act(i, device, batch_queues, model, flags, onnx_frame):
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']
for pos in positions:
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
@ -110,8 +110,16 @@ def act(i, device, batch_queues, model, flags):
position, obs, env_output = env.initial(model, device, flags=flags)
bid_obs_buffer = env_output["begin_buf"]["bid_obs_buffer"]
multiply_obs_buffer = env_output["begin_buf"]["multiply_obs_buffer"]
last_onnx_frame = -1
while True:
# print("posi", position)
if onnx_frame != last_onnx_frame:
last_onnx_frame = onnx_frame
for p in positions:
if p != 'bidding':
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, p)
model.set_onnx_model(p, os.path.abspath(model_path))
for bid_obs in bid_obs_buffer:
obs_z_buf["bidding"].append(bid_obs['z_batch'])
obs_x_batch_buf["bidding"].append(bid_obs["x_batch"])

View File

@ -61,8 +61,8 @@ if __name__ == '__main__':
parser.add_argument('--bid', type=bool, default=True)
parser.add_argument('--title', type=str, default='New')
args = parser.parse_args()
# args.output = True
args.output = False
args.output = True
# args.output = False
args.bid = False
if args.output or args.bid:
args.num_workers = 1
@ -72,15 +72,30 @@ if __name__ == '__main__':
eval_list = [
# {
# 'landlord': { 'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 48545600},
# 'farmer': { 'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400},
# 'two_way': True
# 'landlord': {'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 143539200},
# 'farmer': {'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 143539200},
# 'two_way': False
# },
{
'landlord': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400},
'farmer': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400},
'farmer': { 'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 143539200},
'landlord': { 'folder': 'baselines', 'prefix': 'resnet', 'frame': 23358800},
'two_way': True
}
},
# {
# 'landlord': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400},
# 'farmer': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400},
# 'two_way': False
# },
# {
# 'landlord': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400},
# 'farmer': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400},
# 'two_way': False
# },
# {
# 'landlord': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400},
# 'farmer': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400},
# 'two_way': True
# },
]
for vs in reversed(eval_list):

View File

@ -15,8 +15,8 @@ deck.extend([20, 20, 30, 30])
def get_parser():
parser = argparse.ArgumentParser(description='DouZero: random data generator')
parser.add_argument('--output', default='eval_data_200', type=str)
parser.add_argument('--path', default='baselines/resnet_bidding_27853600.ckpt', type=str)
parser.add_argument('--output', default='eval_data_500', type=str)
parser.add_argument('--path', default='baselines/resnet_landlord_23358800.ckpt', type=str)
parser.add_argument('--num_games', default=200, type=int)
parser.add_argument('--exp_epsilon', default=0.01, type=float)
return parser

View File

@ -3,3 +3,5 @@ GitPython
gitdb2
rlcard
psutil
onnx
onnxruntime