导出onnx模型逻辑,中间状态

This commit is contained in:
zhiyang7 2021-12-14 18:21:01 +08:00
parent 7f0e494109
commit f054fed61c
1 changed files with 15 additions and 0 deletions

View File

@ -228,10 +228,25 @@ def train(flags):
}, checkpointpath + '.new') }, checkpointpath + '.new')
# Save the weights for evaluation purpose # Save the weights for evaluation purpose
dummy_input = (torch.tensor(np.zeros([80, 40, 108]), dtype=torch.float32),
torch.tensor(np.zeros((80, 80)), dtype=torch.int8),
{
'return_value': False,
'flags': {'exp_epsilon':0.001}
},
)
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
model_weights_dir = os.path.expandvars(os.path.expanduser( model_weights_dir = os.path.expandvars(os.path.expanduser(
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt'))) '%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt')))
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir) torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
if position != 'bidding':
torch.onnx.export(
learner_model.get_model(position),
dummy_input,
'%s/model_%s.onnx' % (flags.savedir, position),
input_names=['z_batch','x_batch','flags'],
output_names=['action', 'max_value']
)
shutil.move(checkpointpath + '.new', checkpointpath) shutil.move(checkpointpath + '.new', checkpointpath)