导出onnx模型逻辑,中间状态
This commit is contained in:
parent
7f0e494109
commit
f054fed61c
|
@ -228,10 +228,25 @@ def train(flags):
|
||||||
}, checkpointpath + '.new')
|
}, checkpointpath + '.new')
|
||||||
|
|
||||||
# Save the weights for evaluation purpose
|
# Save the weights for evaluation purpose
|
||||||
|
dummy_input = (torch.tensor(np.zeros([80, 40, 108]), dtype=torch.float32),
|
||||||
|
torch.tensor(np.zeros((80, 80)), dtype=torch.int8),
|
||||||
|
{
|
||||||
|
'return_value': False,
|
||||||
|
'flags': {'exp_epsilon':0.001}
|
||||||
|
},
|
||||||
|
)
|
||||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
||||||
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt')))
|
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt')))
|
||||||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||||
|
if position != 'bidding':
|
||||||
|
torch.onnx.export(
|
||||||
|
learner_model.get_model(position),
|
||||||
|
dummy_input,
|
||||||
|
'%s/model_%s.onnx' % (flags.savedir, position),
|
||||||
|
input_names=['z_batch','x_batch','flags'],
|
||||||
|
output_names=['action', 'max_value']
|
||||||
|
)
|
||||||
shutil.move(checkpointpath + '.new', checkpointpath)
|
shutil.move(checkpointpath + '.new', checkpointpath)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue