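调整封装 (Adjust encapsulation: move the ONNX export parameters into the model classes and gate ONNX usage behind --enable_onnx)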
This commit is contained in:
parent
6fa590a697
commit
3b6ffb2753
|
@ -33,6 +33,8 @@ parser.add_argument('--disable_checkpoint', action='store_true',
|
||||||
help='Disable saving checkpoint')
|
help='Disable saving checkpoint')
|
||||||
parser.add_argument('--savedir', default='douzero_checkpoints',
|
parser.add_argument('--savedir', default='douzero_checkpoints',
|
||||||
help='Root dir where experiment data will be saved')
|
help='Root dir where experiment data will be saved')
|
||||||
|
parser.add_argument('--enable_onnx', action='store_true',
|
||||||
|
help='Use onnx model for train')
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
parser.add_argument('--total_frames', default=100000000000, type=int,
|
parser.add_argument('--total_frames', default=100000000000, type=int,
|
||||||
|
|
|
@ -229,30 +229,20 @@ def train(flags):
|
||||||
}, checkpointpath + '.new')
|
}, checkpointpath + '.new')
|
||||||
|
|
||||||
# Save the weights for evaluation purpose
|
# Save the weights for evaluation purpose
|
||||||
dummy_input = (
|
|
||||||
torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32),
|
|
||||||
torch.tensor(np.zeros((1, 80)), dtype=torch.float32)
|
|
||||||
)
|
|
||||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']:
|
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']:
|
||||||
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
||||||
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
|
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
|
||||||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||||
if position != 'bidding':
|
if flags.enable_onnx and position != 'bidding':
|
||||||
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
|
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
|
||||||
|
onnx_params = learner_model.get_model(position).get_onnx_params()
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
learner_model.get_model(position),
|
learner_model.get_model(position),
|
||||||
dummy_input,
|
onnx_params['args'],
|
||||||
model_path,
|
model_path,
|
||||||
input_names=['z_batch','x_batch'],
|
input_names=onnx_params['input_names'],
|
||||||
output_names=['values'],
|
output_names=onnx_params['output_names'],
|
||||||
dynamic_axes={
|
dynamic_axes=onnx_params['dynamic_axes']
|
||||||
'z_batch': {
|
|
||||||
0: "legal_actions"
|
|
||||||
},
|
|
||||||
'x_batch': {
|
|
||||||
0: "legal_actions"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
onnx_frame.value = frames
|
onnx_frame.value = frames
|
||||||
shutil.move(checkpointpath + '.new', checkpointpath)
|
shutil.move(checkpointpath + '.new', checkpointpath)
|
||||||
|
|
|
@ -25,6 +25,24 @@ class LandlordLstmModel(nn.Module):
|
||||||
self.dense5 = nn.Linear(512, 256)
|
self.dense5 = nn.Linear(512, 256)
|
||||||
self.dense6 = nn.Linear(256, 1)
|
self.dense6 = nn.Linear(256, 1)
|
||||||
|
|
||||||
|
def get_onnx_params(self):
    """Return the arguments used to export this model with ``torch.onnx.export``.

    Returns:
        dict: with the keys
            ``args`` -- tuple of two all-zero float32 dummy inputs of shape
            (1, 5, 432) and (1, 887), passed positionally as ``(z, x)``;
            ``input_names`` -- graph input names, one per entry of ``args``;
            ``output_names`` -- graph output names;
            ``dynamic_axes`` -- declares axis 0 of both inputs as the
            variable-size ``legal_actions`` axis.
    """
    # Build the dummy inputs on the device that holds the weights so the
    # export trace never mixes devices; fall back to CPU for a module
    # without parameters.
    first_param = next(self.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    return {
        'args': (
            torch.zeros((1, 5, 432), dtype=torch.float32, device=device),
            torch.zeros((1, 887), dtype=torch.float32, device=device),
        ),
        'input_names': ['z_batch', 'x_batch'],
        'output_names': ['values'],
        'dynamic_axes': {
            'z_batch': {0: "legal_actions"},
            'x_batch': {0: "legal_actions"},
        },
    }
|
||||||
|
|
||||||
def forward(self, z, x):
|
def forward(self, z, x):
|
||||||
lstm_out, (h_n, _) = self.lstm(z)
|
lstm_out, (h_n, _) = self.lstm(z)
|
||||||
lstm_out = lstm_out[:,-1,:]
|
lstm_out = lstm_out[:,-1,:]
|
||||||
|
@ -53,6 +71,24 @@ class FarmerLstmModel(nn.Module):
|
||||||
self.dense5 = nn.Linear(512, 256)
|
self.dense5 = nn.Linear(512, 256)
|
||||||
self.dense6 = nn.Linear(256, 1)
|
self.dense6 = nn.Linear(256, 1)
|
||||||
|
|
||||||
|
def get_onnx_params(self):
    """Return the arguments used to export this model with ``torch.onnx.export``.

    Returns:
        dict: with the keys
            ``args`` -- tuple of two all-zero float32 dummy inputs of shape
            (1, 5, 432) and (1, 1219), passed positionally as ``(z, x)``;
            ``input_names`` -- graph input names, one per entry of ``args``;
            ``output_names`` -- graph output names;
            ``dynamic_axes`` -- declares axis 0 of both inputs as the
            variable-size ``legal_actions`` axis.
    """
    # Build the dummy inputs on the device that holds the weights so the
    # export trace never mixes devices; fall back to CPU for a module
    # without parameters.
    first_param = next(self.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    return {
        'args': (
            torch.zeros((1, 5, 432), dtype=torch.float32, device=device),
            torch.zeros((1, 1219), dtype=torch.float32, device=device),
        ),
        'input_names': ['z_batch', 'x_batch'],
        'output_names': ['values'],
        'dynamic_axes': {
            'z_batch': {0: "legal_actions"},
            'x_batch': {0: "legal_actions"},
        },
    }
|
||||||
|
|
||||||
def forward(self, z, x):
|
def forward(self, z, x):
|
||||||
lstm_out, (h_n, _) = self.lstm(z)
|
lstm_out, (h_n, _) = self.lstm(z)
|
||||||
lstm_out = lstm_out[:,-1,:]
|
lstm_out = lstm_out[:,-1,:]
|
||||||
|
@ -292,6 +328,24 @@ class GeneralModel(nn.Module):
|
||||||
self.in_planes = planes * block.expansion
|
self.in_planes = planes * block.expansion
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def get_onnx_params(self):
    """Return the arguments used to export this model with ``torch.onnx.export``.

    Returns:
        dict: with the keys
            ``args`` -- tuple of two all-zero float32 dummy inputs of shape
            (1, 40, 108) and (1, 80), passed positionally as ``(z, x)``;
            ``input_names`` -- graph input names, one per entry of ``args``;
            ``output_names`` -- graph output names;
            ``dynamic_axes`` -- declares axis 0 of both inputs as the
            variable-size ``legal_actions`` axis.
    """
    # Build the dummy inputs on the device that holds the weights so the
    # export trace never mixes devices; fall back to CPU for a module
    # without parameters.
    first_param = next(self.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    return {
        'args': (
            torch.zeros((1, 40, 108), dtype=torch.float32, device=device),
            torch.zeros((1, 80), dtype=torch.float32, device=device),
        ),
        'input_names': ['z_batch', 'x_batch'],
        'output_names': ['values'],
        'dynamic_axes': {
            'z_batch': {0: "legal_actions"},
            'x_batch': {0: "legal_actions"},
        },
    }
|
||||||
|
|
||||||
def forward(self, z, x):
|
def forward(self, z, x):
|
||||||
out = F.relu(self.bn1(self.conv1(z)))
|
out = F.relu(self.bn1(self.conv1(z)))
|
||||||
out = self.layer1(out)
|
out = self.layer1(out)
|
||||||
|
@ -433,10 +487,28 @@ class OldModel:
|
||||||
self.models['landlord_front'] = FarmerLstmModel().to(torch.device(device))
|
self.models['landlord_front'] = FarmerLstmModel().to(torch.device(device))
|
||||||
self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
|
self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
|
||||||
self.models['bidding'] = BidModel().to(torch.device(device))
|
self.models['bidding'] = BidModel().to(torch.device(device))
|
||||||
|
self.onnx_models = {
|
||||||
|
'landlord': None,
|
||||||
|
'landlord_up': None,
|
||||||
|
'landlord_front': None,
|
||||||
|
'landlord_down': None,
|
||||||
|
'bidding': None
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_onnx_model(self, position, model_path):
    """Open the ONNX model at ``model_path`` and register it for ``position``.

    ``model_path`` is passed through ``get_example`` and the result is opened
    as an ``onnxruntime.InferenceSession``. The session replaces whatever was
    stored in ``self.onnx_models[position]``.
    """
    resolved_path = get_example(model_path)
    session = onnxruntime.InferenceSession(resolved_path)
    self.onnx_models[position] = session
|
||||||
|
|
||||||
|
def get_onnx_params(self, position):
    """Return the ONNX export parameters of the sub-model for ``position``.

    Args:
        position: key into ``self.models`` selecting the sub-model.

    Returns:
        Whatever the selected sub-model's ``get_onnx_params()`` returns.

    Raises:
        KeyError: if ``position`` is not a key of ``self.models``.
    """
    # The result must be returned; without it this wrapper always gave None.
    return self.models[position].get_onnx_params()
|
||||||
|
|
||||||
def forward(self, position, z, x, return_value=False, flags=None):
|
def forward(self, position, z, x, return_value=False, flags=None):
|
||||||
|
model = self.onnx_models[position]
|
||||||
|
if model is None:
|
||||||
model = self.models[position]
|
model = self.models[position]
|
||||||
values = model.forward(z, x)['values']
|
values = model.forward(z, x)['values']
|
||||||
|
else:
|
||||||
|
onnx_out = model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
|
||||||
|
values = torch.tensor(onnx_out[0])
|
||||||
if return_value:
|
if return_value:
|
||||||
return dict(values=values)
|
return dict(values=values)
|
||||||
else:
|
else:
|
||||||
|
@ -497,6 +569,9 @@ class Model:
|
||||||
def set_onnx_model(self, position, model_path):
    """Register an ONNX inference session for ``position``.

    The value returned by ``get_example(model_path)`` is handed to
    ``onnxruntime.InferenceSession``, and the resulting session is stored
    under ``position`` in ``self.onnx_models``.
    """
    onnx_file = get_example(model_path)
    inference_session = onnxruntime.InferenceSession(onnx_file)
    self.onnx_models[position] = inference_session
|
||||||
|
|
||||||
|
def get_onnx_params(self, position):
    """Return the ONNX export parameters of the sub-model for ``position``.

    Args:
        position: key into ``self.models`` selecting the sub-model.

    Returns:
        Whatever the selected sub-model's ``get_onnx_params()`` returns.

    Raises:
        KeyError: if ``position`` is not a key of ``self.models``.
    """
    # The result must be returned; without it this wrapper always gave None.
    return self.models[position].get_onnx_params()
|
||||||
|
|
||||||
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
|
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
|
||||||
model = self.onnx_models[position]
|
model = self.onnx_models[position]
|
||||||
if model is None:
|
if model is None:
|
||||||
|
|
|
@ -113,7 +113,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
|
||||||
last_onnx_frame = -1
|
last_onnx_frame = -1
|
||||||
while True:
|
while True:
|
||||||
# print("posi", position)
|
# print("posi", position)
|
||||||
if onnx_frame.value != last_onnx_frame:
|
if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
|
||||||
last_onnx_frame = onnx_frame.value
|
last_onnx_frame = onnx_frame.value
|
||||||
for p in positions:
|
for p in positions:
|
||||||
if p != 'bidding':
|
if p != 'bidding':
|
||||||
|
|
Loading…
Reference in New Issue