调整封装

This commit is contained in:
zhiyang7 2021-12-15 12:26:47 +08:00
parent 6fa590a697
commit 3b6ffb2753
4 changed files with 86 additions and 19 deletions

View File

@ -33,6 +33,8 @@ parser.add_argument('--disable_checkpoint', action='store_true',
help='Disable saving checkpoint')
parser.add_argument('--savedir', default='douzero_checkpoints',
help='Root dir where experiment data will be saved')
parser.add_argument('--enable_onnx', action='store_true',
help='Use onnx model for train')
# Hyperparameters
parser.add_argument('--total_frames', default=100000000000, type=int,

View File

@ -229,30 +229,20 @@ def train(flags):
}, checkpointpath + '.new')
# Save the weights for evaluation purpose
dummy_input = (
torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32),
torch.tensor(np.zeros((1, 80)), dtype=torch.float32)
)
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']:
model_weights_dir = os.path.expandvars(os.path.expanduser(
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
if position != 'bidding':
if flags.enable_onnx and position != 'bidding':
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
onnx_params = learner_model.get_model(position).get_onnx_params()
torch.onnx.export(
learner_model.get_model(position),
dummy_input,
onnx_params['args'],
model_path,
input_names=['z_batch','x_batch'],
output_names=['values'],
dynamic_axes={
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
}
}
input_names=onnx_params['input_names'],
output_names=onnx_params['output_names'],
dynamic_axes=onnx_params['dynamic_axes']
)
onnx_frame.value = frames
shutil.move(checkpointpath + '.new', checkpointpath)

View File

@ -25,6 +25,24 @@ class LandlordLstmModel(nn.Module):
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def get_onnx_params(self):
return {
'args': (
torch.tensor(np.zeros((1, 5, 432)), dtype=torch.float32),
torch.tensor(np.zeros((1, 887)), dtype=torch.float32),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
@ -53,6 +71,24 @@ class FarmerLstmModel(nn.Module):
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def get_onnx_params(self):
return {
'args': (
torch.tensor(np.zeros((1, 5, 432)), dtype=torch.float32),
torch.tensor(np.zeros((1, 1219)), dtype=torch.float32)
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
@ -292,6 +328,24 @@ class GeneralModel(nn.Module):
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def get_onnx_params(self):
return {
'args': (
torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32),
torch.tensor(np.zeros((1, 80)), dtype=torch.float32)
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
}
}
}
def forward(self, z, x):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
@ -433,10 +487,28 @@ class OldModel:
self.models['landlord_front'] = FarmerLstmModel().to(torch.device(device))
self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
self.onnx_models = {
'landlord': None,
'landlord_up': None,
'landlord_front': None,
'landlord_down': None,
'bidding': None
}
def set_onnx_model(self, position, model_path):
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path))
def get_onnx_params(self, position):
self.models[position].get_onnx_params()
def forward(self, position, z, x, return_value=False, flags=None):
model = self.onnx_models[position]
if model is None:
model = self.models[position]
values = model.forward(z, x)['values']
else:
onnx_out = model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
values = torch.tensor(onnx_out[0])
if return_value:
return dict(values=values)
else:
@ -497,6 +569,9 @@ class Model:
def set_onnx_model(self, position, model_path):
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path))
def get_onnx_params(self, position):
self.models[position].get_onnx_params()
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
model = self.onnx_models[position]
if model is None:

View File

@ -113,7 +113,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
last_onnx_frame = -1
while True:
# print("posi", position)
if onnx_frame.value != last_onnx_frame:
if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
last_onnx_frame = onnx_frame.value
for p in positions:
if p != 'bidding':