diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index f0808fa..bef2003 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -228,10 +228,25 @@ def train(flags): }, checkpointpath + '.new') # Save the weights for evaluation purpose + dummy_input = (torch.tensor(np.zeros([80, 40, 108]), dtype=torch.float32), + torch.tensor(np.zeros((80, 80)), dtype=torch.int8), + { + 'return_value': False, + 'flags': {'exp_epsilon':0.001} + }, + ) for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] model_weights_dir = os.path.expandvars(os.path.expanduser( '%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt'))) torch.save(learner_model.get_model(position).state_dict(), model_weights_dir) + if position != 'bidding': + torch.onnx.export( + learner_model.get_model(position), + dummy_input, + '%s/model_%s.onnx' % (flags.savedir, position), + input_names=['z_batch','x_batch','flags'], + output_names=['action', 'max_value'] + ) shutil.move(checkpointpath + '.new', checkpointpath)