From f054fed61cc1c56a95572781ac559c34863e3bcf Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Tue, 14 Dec 2021 18:21:01 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AF=BC=E5=87=BAonnx=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E4=B8=AD=E9=97=B4=E7=8A=B6=E6=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/dmc/dmc.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index f0808fa..bef2003 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -228,10 +228,25 @@ def train(flags): }, checkpointpath + '.new') # Save the weights for evaluation purpose + dummy_input = (torch.tensor(np.zeros([80, 40, 108]), dtype=torch.float32), + torch.tensor(np.zeros((80, 80)), dtype=torch.int8), + { + 'return_value': False, + 'flags': {'exp_epsilon':0.001} + }, + ) for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] model_weights_dir = os.path.expandvars(os.path.expanduser( '%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt'))) torch.save(learner_model.get_model(position).state_dict(), model_weights_dir) + if position != 'bidding': + torch.onnx.export( + learner_model.get_model(position), + dummy_input, + '%s/model_%s.onnx' % (flags.savedir, position), + input_names=['z_batch','x_batch','flags'], + output_names=['action', 'max_value'] + ) shutil.move(checkpointpath + '.new', checkpointpath)