导出onnx模型逻辑,中间状态
This commit is contained in:
parent
7f0e494109
commit
f054fed61c
|
@ -228,10 +228,25 @@ def train(flags):
|
|||
}, checkpointpath + '.new')
|
||||
|
||||
# Save the weights for evaluation purpose
|
||||
dummy_input = (torch.tensor(np.zeros([80, 40, 108]), dtype=torch.float32),
|
||||
torch.tensor(np.zeros((80, 80)), dtype=torch.int8),
|
||||
{
|
||||
'return_value': False,
|
||||
'flags': {'exp_epsilon':0.001}
|
||||
},
|
||||
)
|
||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
||||
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt')))
|
||||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||
if position != 'bidding':
|
||||
torch.onnx.export(
|
||||
learner_model.get_model(position),
|
||||
dummy_input,
|
||||
'%s/model_%s.onnx' % (flags.savedir, position),
|
||||
input_names=['z_batch','x_batch','flags'],
|
||||
output_names=['action', 'max_value']
|
||||
)
|
||||
shutil.move(checkpointpath + '.new', checkpointpath)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue