使用onnx进行infer逻辑,未完成

This commit is contained in:
ZaneYork 2021-12-14 22:55:03 +08:00
parent f054fed61c
commit 0cb3d040cb
6 changed files with 133 additions and 111 deletions

View File

@ -19,6 +19,7 @@ import psutil
import shutil
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']}
onnx_frame = mp.Value('d', -1)
def compute_loss(logits, targets):
loss = ((logits.squeeze(-1) - targets)**2).mean()
@ -72,13 +73,14 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock):
actor_model.get_model(position).load_state_dict(model.state_dict())
return stats
def train(flags):
def train(flags):
"""
This is the main function for training. It will first
initialize everything, such as buffers, optimizers, etc.
Then it will start subprocesses as actors. Then, it will call
learning function with multiple threads.
"""
global onnx_frame
if not flags.actor_device_cpu or flags.training_device != 'cpu':
if not torch.cuda.is_available():
raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`")
@ -171,7 +173,7 @@ def train(flags):
for i in range(num_actors):
actor = mp.Process(
target=act,
args=(i, device, batch_queues, models[device], flags))
args=(i, device, batch_queues, models[device], flags, onnx_frame))
actor.daemon = True
actor.start()
actor_processes.append(actor)
@ -186,7 +188,7 @@ def train(flags):
nonlocal frames, position_frames, stats
while frames < flags.total_frames:
batch = get_batch(batch_queues, position, flags, local_lock)
_stats = learn(position, models, learner_model.get_model(position), batch,
_stats = learn(position, models, learner_model.get_model(position), batch,
optimizers[position], flags, position_lock)
with lock:
for k in _stats:
@ -212,8 +214,9 @@ def train(flags):
thread.setDaemon(True)
thread.start()
threads.append(thread)
def checkpoint(frames):
global onnx_frame
if flags.disable_checkpoint:
return
log.info('Saving checkpoint to %s', checkpointpath)
@ -228,25 +231,32 @@ def train(flags):
}, checkpointpath + '.new')
# Save the weights for evaluation purpose
dummy_input = (torch.tensor(np.zeros([80, 40, 108]), dtype=torch.float32),
torch.tensor(np.zeros((80, 80)), dtype=torch.int8),
{
'return_value': False,
'flags': {'exp_epsilon':0.001}
},
)
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
dummy_input = (
torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32),
torch.tensor(np.zeros((1, 80)), dtype=torch.float32)
)
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']:
model_weights_dir = os.path.expandvars(os.path.expanduser(
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt')))
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
if position != 'bidding':
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
torch.onnx.export(
learner_model.get_model(position),
dummy_input,
'%s/model_%s.onnx' % (flags.savedir, position),
input_names=['z_batch','x_batch','flags'],
output_names=['action', 'max_value']
model_path,
input_names=['z_batch','x_batch'],
output_names=['values'],
dynamic_axes={
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
}
}
)
onnx_frame = frames
shutil.move(checkpointpath + '.new', checkpointpath)
@ -260,7 +270,7 @@ def train(flags):
start_time = timer()
time.sleep(5)
if timer() - last_checkpoint_time > flags.save_interval * 60:
if timer() - last_checkpoint_time > flags.save_interval * 60:
checkpoint(frames)
last_checkpoint_time = timer()
end_time = timer()
@ -287,7 +297,7 @@ def train(flags):
pprint.pformat(stats))
except KeyboardInterrupt:
return
return
else:
for thread in threads:
thread.join()

View File

@ -6,9 +6,14 @@ models into one class for convenience.
import numpy as np
import torch
import onnxruntime
from onnxruntime.datasets import get_example
from torch import nn
import torch.nn.functional as F
def to_numpy(tensor):
    """Return the contents of a torch tensor as a NumPy array on the CPU.

    ``detach()`` is valid on any tensor, whether or not it tracks
    gradients, so it is applied unconditionally instead of branching on
    ``requires_grad``.
    """
    return tensor.detach().cpu().numpy()
class LandlordLstmModel(nn.Module):
def __init__(self):
super().__init__()
@ -20,7 +25,7 @@ class LandlordLstmModel(nn.Module):
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def forward(self, z, x, return_value=False, flags=None):
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
@ -35,14 +40,7 @@ class LandlordLstmModel(nn.Module):
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
return dict(values=x)
class FarmerLstmModel(nn.Module):
def __init__(self):
@ -55,7 +53,7 @@ class FarmerLstmModel(nn.Module):
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def forward(self, z, x, return_value=False, flags=None):
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
@ -70,14 +68,7 @@ class FarmerLstmModel(nn.Module):
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
return dict(values=x)
class LandlordLstmModelLegacy(nn.Module):
def __init__(self):
@ -90,7 +81,7 @@ class LandlordLstmModelLegacy(nn.Module):
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def forward(self, z, x, return_value=False, flags=None):
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
@ -105,14 +96,7 @@ class LandlordLstmModelLegacy(nn.Module):
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
return dict(values=x)
class FarmerLstmModelLegacy(nn.Module):
def __init__(self):
@ -125,7 +109,7 @@ class FarmerLstmModelLegacy(nn.Module):
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def forward(self, z, x, return_value=False, flags=None):
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
@ -140,14 +124,7 @@ class FarmerLstmModelLegacy(nn.Module):
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
return dict(values=x)
class GeneralModel1(nn.Module):
def __init__(self):
@ -185,7 +162,7 @@ class GeneralModel1(nn.Module):
self.dense5 = nn.Linear(512, 512)
self.dense6 = nn.Linear(512, 1)
def forward(self, z, x, return_value=False, flags=None, debug=False):
def forward(self, z, x):
z = z.unsqueeze(1)
z = self.conv_z_1(z)
z = z.squeeze(-1)
@ -209,14 +186,7 @@ class GeneralModel1(nn.Module):
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action, max_value=torch.max(x))
return dict(values=x)
# The residual block used by ResNet18 and ResNet34 consists of two 3x3 convolutions
@ -278,7 +248,7 @@ class GeneralModelLegacy(nn.Module):
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, z, x, return_value=False, flags=None, debug=False):
def forward(self, z, x):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
out = self.layer2(out)
@ -291,14 +261,7 @@ class GeneralModelLegacy(nn.Module):
out = F.leaky_relu_(self.linear3(out))
out = F.leaky_relu_(self.linear4(out))
out = F.leaky_relu_(self.linear5(out))
if return_value:
return dict(values=out)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(out.shape[0], (1,))[0]
else:
action = torch.argmax(out,dim=0)[0]
return dict(action=action, max_value=torch.max(out))
return dict(values=out)
class GeneralModel(nn.Module):
def __init__(self):
@ -329,7 +292,7 @@ class GeneralModel(nn.Module):
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, z, x, return_value=False, flags=None, debug=False):
def forward(self, z, x):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
out = self.layer2(out)
@ -342,17 +305,7 @@ class GeneralModel(nn.Module):
out = F.leaky_relu_(self.linear3(out))
out = F.leaky_relu_(self.linear4(out))
out = F.leaky_relu_(self.linear5(out))
if return_value:
return dict(values=out)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(out.shape[0], (1,))[0]
else:
action = torch.argmax(out,dim=0)[0]
return dict(action=action, max_value=torch.max(out))
return dict(values=out)
class BidModel(nn.Module):
@ -366,7 +319,7 @@ class BidModel(nn.Module):
self.dense5 = nn.Linear(512, 512)
self.dense6 = nn.Linear(512, 1)
def forward(self, z, x, return_value=False, flags=None, debug=False):
def forward(self, z, x):
x = self.dense1(x)
x = F.leaky_relu(x)
# x = F.relu(x)
@ -383,14 +336,7 @@ class BidModel(nn.Module):
# x = F.relu(x)
x = F.leaky_relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action, max_value=torch.max(x))
return dict(values=x)
# Model dict is only used in evaluation but not training
@ -438,9 +384,17 @@ class General_Model:
self.models['landlord_down'] = GeneralModel1().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
    """Run the network registered for ``position`` and optionally pick an action.

    Args:
        position: key into ``self.models`` selecting the network to use.
        z: first input, passed unchanged to the network's ``forward``.
        x: second input, passed unchanged to the network's ``forward``.
        return_value: when true, return only the raw network output.
        flags: optional object exposing ``exp_epsilon``; when that value is
            positive, a random index is chosen with that probability.
        debug: accepted for call compatibility; not used here.

    Returns:
        ``dict(values=...)`` when ``return_value`` is true, otherwise
        ``dict(action=..., max_value=...)``.
    """
    model = self.models[position]
    # The networks now only compute values; action selection lives here.
    values = model.forward(z, x)['values']
    if return_value:
        return dict(values=values)

    explore = (
        flags is not None
        and flags.exp_epsilon > 0
        and np.random.rand() < flags.exp_epsilon
    )
    if explore:
        # Uniformly random index along the first axis of ``values``.
        action = torch.randint(values.shape[0], (1,))[0]
    else:
        # Greedy choice: index of the largest value along the first axis.
        action = torch.argmax(values, dim=0)[0]
    return dict(action=action, max_value=torch.max(values))
def share_memory(self):
self.models['landlord'].share_memory()
@ -480,9 +434,17 @@ class OldModel:
self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
def forward(self, position, z, x, return_value=False, flags=None):
    """Run the network registered for ``position`` and optionally pick an action.

    Args:
        position: key into ``self.models`` selecting the network to use.
        z: first input, passed unchanged to the network's ``forward``.
        x: second input, passed unchanged to the network's ``forward``.
        return_value: when true, return only the raw network output.
        flags: optional object exposing ``exp_epsilon``; when that value is
            positive, a random index is chosen with that probability.

    Returns:
        ``dict(values=...)`` when ``return_value`` is true, otherwise
        ``dict(action=..., max_value=...)``.
    """
    model = self.models[position]
    # The networks now only compute values; action selection lives here.
    values = model.forward(z, x)['values']
    if return_value:
        return dict(values=values)

    explore = (
        flags is not None
        and flags.exp_epsilon > 0
        and np.random.rand() < flags.exp_epsilon
    )
    if explore:
        # Uniformly random index along the first axis of ``values``.
        action = torch.randint(values.shape[0], (1,))[0]
    else:
        # Greedy choice: index of the largest value along the first axis.
        action = torch.argmax(values, dim=0)[0]
    return dict(action=action, max_value=torch.max(values))
def share_memory(self):
self.models['landlord'].share_memory()
@ -523,10 +485,35 @@ class Model:
self.models['landlord_front'] = GeneralModel().to(torch.device(device))
self.models['landlord_down'] = GeneralModel().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
self.onnx_models = {
'landlord': None,
'landlord_up': None,
'landlord_front': None,
'landlord_down': None,
'bidding': None
}
self.models['bidding'] = BidModel().to(torch.device(device))
def set_onnx_model(self, position, model_path):
    """Register the exported ONNX file to use for ``position``.

    Args:
        position: key into ``self.onnx_models``.
        model_path: path of the ``.onnx`` file.
    """
    # NOTE(review): get_example comes from onnxruntime.datasets; presumably
    # it only resolves/validates the path here - confirm it is the intended
    # helper for files outside onnxruntime's bundled examples.
    self.onnx_models[position] = get_example(model_path)
    # A file re-exported at the same path must be reloaded, so drop any
    # inference session cached for this position.
    sessions = getattr(self, '_onnx_sessions', None)
    if sessions is not None:
        sessions.pop(position, None)

def forward(self, position, z, x, return_value=False, flags=None, debug=False):
    """Evaluate ``position`` via ONNX Runtime if registered, else via PyTorch.

    Args:
        position: key into ``self.onnx_models`` / ``self.models``.
        z: first input; fed to the ONNX graph as ``z_batch`` or passed to
            the PyTorch network's ``forward``.
        x: second input; fed to the ONNX graph as ``x_batch`` or passed to
            the PyTorch network's ``forward``.
        return_value: when true, return only the raw network output.
        flags: optional object exposing ``exp_epsilon``; when that value is
            positive, a random index is chosen with that probability.
        debug: accepted for call compatibility; not used here.

    Returns:
        ``dict(values=...)`` when ``return_value`` is true, otherwise
        ``dict(action=..., max_value=...)``.
    """
    model_path = self.onnx_models[position]
    if model_path is None:
        # No ONNX export registered for this position: use the torch module.
        values = self.models[position].forward(z, x)['values']
    else:
        # Building an InferenceSession loads the model file, so keep one
        # session per position instead of creating one on every call.
        sessions = getattr(self, '_onnx_sessions', None)
        if sessions is None:
            sessions = self._onnx_sessions = {}
        sess = sessions.get(position)
        if sess is None:
            sess = sessions[position] = onnxruntime.InferenceSession(model_path)
        onnx_out = sess.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
        values = torch.tensor(onnx_out[0])

    if return_value:
        return dict(values=values)

    explore = (
        flags is not None
        and flags.exp_epsilon > 0
        and np.random.rand() < flags.exp_epsilon
    )
    if explore:
        # Uniformly random index along the first axis of ``values``.
        action = torch.randint(values.shape[0], (1,))[0]
    else:
        # Greedy choice: index of the largest value along the first axis.
        action = torch.argmax(values, dim=0)[0]
    return dict(action=action, max_value=torch.max(values))
def share_memory(self):
self.models['landlord'].share_memory()

View File

@ -81,7 +81,7 @@ def create_optimizers(flags, learner_model):
return optimizers
def act(i, device, batch_queues, model, flags):
def act(i, device, batch_queues, model, flags, onnx_frame):
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']
for pos in positions:
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
@ -110,8 +110,16 @@ def act(i, device, batch_queues, model, flags):
position, obs, env_output = env.initial(model, device, flags=flags)
bid_obs_buffer = env_output["begin_buf"]["bid_obs_buffer"]
multiply_obs_buffer = env_output["begin_buf"]["multiply_obs_buffer"]
last_onnx_frame = -1
while True:
# print("posi", position)
if onnx_frame != last_onnx_frame:
last_onnx_frame = onnx_frame
for p in positions:
if p != 'bidding':
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, p)
model.set_onnx_model(p, os.path.abspath(model_path))
for bid_obs in bid_obs_buffer:
obs_z_buf["bidding"].append(bid_obs['z_batch'])
obs_x_batch_buf["bidding"].append(bid_obs["x_batch"])

View File

@ -61,8 +61,8 @@ if __name__ == '__main__':
parser.add_argument('--bid', type=bool, default=True)
parser.add_argument('--title', type=str, default='New')
args = parser.parse_args()
# args.output = True
args.output = False
args.output = True
# args.output = False
args.bid = False
if args.output or args.bid:
args.num_workers = 1
@ -72,15 +72,30 @@ if __name__ == '__main__':
eval_list = [
# {
# 'landlord': { 'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 48545600},
# 'farmer': { 'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400},
# 'two_way': True
# 'landlord': {'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 143539200},
# 'farmer': {'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 143539200},
# 'two_way': False
# },
{
'landlord': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400},
'farmer': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400},
'farmer': { 'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 143539200},
'landlord': { 'folder': 'baselines', 'prefix': 'resnet', 'frame': 23358800},
'two_way': True
}
},
# {
# 'landlord': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400},
# 'farmer': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400},
# 'two_way': False
# },
# {
# 'landlord': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400},
# 'farmer': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400},
# 'two_way': False
# },
# {
# 'landlord': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400},
# 'farmer': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400},
# 'two_way': True
# },
]
for vs in reversed(eval_list):

View File

@ -15,8 +15,8 @@ deck.extend([20, 20, 30, 30])
def get_parser():
    """Build the command-line parser for the random data generator.

    Options and defaults:
        --output       str,   'eval_data_500'
        --path         str,   'baselines/resnet_landlord_23358800.ckpt'
        --num_games    int,   200
        --exp_epsilon  float, 0.01

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='DouZero: random data generator')
    # Each option is registered exactly once; argparse rejects a repeated
    # option string with a conflict error.
    parser.add_argument('--output', default='eval_data_500', type=str)
    parser.add_argument('--path', default='baselines/resnet_landlord_23358800.ckpt', type=str)
    parser.add_argument('--num_games', default=200, type=int)
    parser.add_argument('--exp_epsilon', default=0.01, type=float)
    return parser

View File

@ -3,3 +3,5 @@ GitPython
gitdb2
rlcard
psutil
onnx
onnxruntime