调整封装
This commit is contained in:
parent
6fa590a697
commit
3b6ffb2753
|
@ -33,6 +33,8 @@ parser.add_argument('--disable_checkpoint', action='store_true',
|
|||
help='Disable saving checkpoint')
|
||||
parser.add_argument('--savedir', default='douzero_checkpoints',
|
||||
help='Root dir where experiment data will be saved')
|
||||
parser.add_argument('--enable_onnx', action='store_true',
|
||||
help='Use onnx model for train')
|
||||
|
||||
# Hyperparameters
|
||||
parser.add_argument('--total_frames', default=100000000000, type=int,
|
||||
|
|
|
@ -229,30 +229,20 @@ def train(flags):
|
|||
}, checkpointpath + '.new')
|
||||
|
||||
# Save the weights for evaluation purpose
|
||||
dummy_input = (
|
||||
torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32),
|
||||
torch.tensor(np.zeros((1, 80)), dtype=torch.float32)
|
||||
)
|
||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']:
|
||||
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
||||
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
|
||||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||
if position != 'bidding':
|
||||
if flags.enable_onnx and position != 'bidding':
|
||||
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
|
||||
onnx_params = learner_model.get_model(position).get_onnx_params()
|
||||
torch.onnx.export(
|
||||
learner_model.get_model(position),
|
||||
dummy_input,
|
||||
onnx_params['args'],
|
||||
model_path,
|
||||
input_names=['z_batch','x_batch'],
|
||||
output_names=['values'],
|
||||
dynamic_axes={
|
||||
'z_batch': {
|
||||
0: "legal_actions"
|
||||
},
|
||||
'x_batch': {
|
||||
0: "legal_actions"
|
||||
}
|
||||
}
|
||||
input_names=onnx_params['input_names'],
|
||||
output_names=onnx_params['output_names'],
|
||||
dynamic_axes=onnx_params['dynamic_axes']
|
||||
)
|
||||
onnx_frame.value = frames
|
||||
shutil.move(checkpointpath + '.new', checkpointpath)
|
||||
|
|
|
@ -25,6 +25,24 @@ class LandlordLstmModel(nn.Module):
|
|||
self.dense5 = nn.Linear(512, 256)
|
||||
self.dense6 = nn.Linear(256, 1)
|
||||
|
||||
def get_onnx_params(self):
    """Describe how this model should be exported to ONNX.

    Returns:
        dict: four entries --
            ``args``: a pair of all-zero float32 example tensors with
            shapes (1, 5, 432) and (1, 887), in the same order as
            ``input_names``;
            ``input_names``: names given to the two graph inputs;
            ``output_names``: name given to the single graph output;
            ``dynamic_axes``: marks dimension 0 of both inputs as
            variable-length under the label "legal_actions".
    """
    input_names = ['z_batch', 'x_batch']
    example_inputs = (
        torch.zeros((1, 5, 432), dtype=torch.float32),
        torch.zeros((1, 887), dtype=torch.float32),
    )
    return {
        'args': example_inputs,
        'input_names': input_names,
        'output_names': ['values'],
        # Only the leading dimension of each input may vary at run time.
        'dynamic_axes': {name: {0: "legal_actions"} for name in input_names},
    }
|
||||
|
||||
def forward(self, z, x):
|
||||
lstm_out, (h_n, _) = self.lstm(z)
|
||||
lstm_out = lstm_out[:,-1,:]
|
||||
|
@ -53,6 +71,24 @@ class FarmerLstmModel(nn.Module):
|
|||
self.dense5 = nn.Linear(512, 256)
|
||||
self.dense6 = nn.Linear(256, 1)
|
||||
|
||||
def get_onnx_params(self):
    """Describe how this model should be exported to ONNX.

    Returns:
        dict: four entries --
            ``args``: a pair of all-zero float32 example tensors with
            shapes (1, 5, 432) and (1, 1219), in the same order as
            ``input_names``;
            ``input_names``: names given to the two graph inputs;
            ``output_names``: name given to the single graph output;
            ``dynamic_axes``: marks dimension 0 of both inputs as
            variable-length under the label "legal_actions".
    """
    input_names = ['z_batch', 'x_batch']
    example_inputs = (
        torch.zeros((1, 5, 432), dtype=torch.float32),
        torch.zeros((1, 1219), dtype=torch.float32),
    )
    return {
        'args': example_inputs,
        'input_names': input_names,
        'output_names': ['values'],
        # Only the leading dimension of each input may vary at run time.
        'dynamic_axes': {name: {0: "legal_actions"} for name in input_names},
    }
|
||||
|
||||
def forward(self, z, x):
|
||||
lstm_out, (h_n, _) = self.lstm(z)
|
||||
lstm_out = lstm_out[:,-1,:]
|
||||
|
@ -292,6 +328,24 @@ class GeneralModel(nn.Module):
|
|||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def get_onnx_params(self):
    """Describe how this model should be exported to ONNX.

    Returns:
        dict: four entries --
            ``args``: a pair of all-zero float32 example tensors with
            shapes (1, 40, 108) and (1, 80), in the same order as
            ``input_names``;
            ``input_names``: names given to the two graph inputs;
            ``output_names``: name given to the single graph output;
            ``dynamic_axes``: marks dimension 0 of both inputs as
            variable-length under the label "legal_actions".
    """
    input_names = ['z_batch', 'x_batch']
    example_inputs = (
        torch.zeros((1, 40, 108), dtype=torch.float32),
        torch.zeros((1, 80), dtype=torch.float32),
    )
    return {
        'args': example_inputs,
        'input_names': input_names,
        'output_names': ['values'],
        # Only the leading dimension of each input may vary at run time.
        'dynamic_axes': {name: {0: "legal_actions"} for name in input_names},
    }
|
||||
|
||||
def forward(self, z, x):
|
||||
out = F.relu(self.bn1(self.conv1(z)))
|
||||
out = self.layer1(out)
|
||||
|
@ -433,10 +487,28 @@ class OldModel:
|
|||
self.models['landlord_front'] = FarmerLstmModel().to(torch.device(device))
|
||||
self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
|
||||
self.models['bidding'] = BidModel().to(torch.device(device))
|
||||
self.onnx_models = {
|
||||
'landlord': None,
|
||||
'landlord_up': None,
|
||||
'landlord_front': None,
|
||||
'landlord_down': None,
|
||||
'bidding': None
|
||||
}
|
||||
|
||||
def set_onnx_model(self, position, model_path):
    """Create an ONNX Runtime session for ``position`` from ``model_path``.

    The path is first passed through ``get_example`` (defined outside this
    block -- NOTE(review): confirm how it resolves relative paths), then the
    resulting session is stored in ``self.onnx_models[position]``, replacing
    whatever was there before.
    """
    resolved_path = get_example(model_path)
    session = onnxruntime.InferenceSession(resolved_path)
    self.onnx_models[position] = session
|
||||
|
||||
def get_onnx_params(self, position):
    """Return the ONNX export parameters of the model stored for ``position``.

    Delegates to ``get_onnx_params()`` on ``self.models[position]`` and hands
    its result back to the caller (the previous version computed the value
    but dropped it, so callers always received ``None``).

    Raises:
        KeyError: if ``position`` is not a key of ``self.models``.
    """
    return self.models[position].get_onnx_params()
|
||||
|
||||
def forward(self, position, z, x, return_value=False, flags=None):
|
||||
model = self.models[position]
|
||||
values = model.forward(z, x)['values']
|
||||
model = self.onnx_models[position]
|
||||
if model is None:
|
||||
model = self.models[position]
|
||||
values = model.forward(z, x)['values']
|
||||
else:
|
||||
onnx_out = model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
|
||||
values = torch.tensor(onnx_out[0])
|
||||
if return_value:
|
||||
return dict(values=values)
|
||||
else:
|
||||
|
@ -497,6 +569,9 @@ class Model:
|
|||
def set_onnx_model(self, position, model_path):
    """Create an ONNX Runtime session for ``position`` from ``model_path``.

    The path is first passed through ``get_example`` (defined outside this
    block -- NOTE(review): confirm how it resolves relative paths), then the
    resulting session is stored in ``self.onnx_models[position]``, replacing
    whatever was there before.
    """
    resolved_path = get_example(model_path)
    session = onnxruntime.InferenceSession(resolved_path)
    self.onnx_models[position] = session
|
||||
|
||||
def get_onnx_params(self, position):
    """Return the ONNX export parameters of the model stored for ``position``.

    Delegates to ``get_onnx_params()`` on ``self.models[position]`` and hands
    its result back to the caller (the previous version computed the value
    but dropped it, so callers always received ``None``).

    Raises:
        KeyError: if ``position`` is not a key of ``self.models``.
    """
    return self.models[position].get_onnx_params()
|
||||
|
||||
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
|
||||
model = self.onnx_models[position]
|
||||
if model is None:
|
||||
|
|
|
@ -113,7 +113,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
|
|||
last_onnx_frame = -1
|
||||
while True:
|
||||
# print("posi", position)
|
||||
if onnx_frame.value != last_onnx_frame:
|
||||
if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
|
||||
last_onnx_frame = onnx_frame.value
|
||||
for p in positions:
|
||||
if p != 'bidding':
|
||||
|
|
Loading…
Reference in New Issue