使用onnx进行infer逻辑,未完成

This commit is contained in:
ZaneYork 2021-12-14 23:06:07 +08:00
parent 0cb3d040cb
commit 9f2e8f74f3
2 changed files with 3 additions and 2 deletions

View File

@ -53,7 +53,7 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock):
if len(episode_returns) > 0: if len(episode_returns) > 0:
mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device)) mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device))
with lock: with lock:
learner_outputs = model(obs_z, obs_x, return_value=True) learner_outputs = model(obs_z, obs_x)
if position == "bidding": if position == "bidding":
loss = compute_loss(learner_outputs['values'], target) loss = compute_loss(learner_outputs['values'], target)
# pass # pass

View File

@ -118,7 +118,8 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
for p in positions: for p in positions:
if p != 'bidding': if p != 'bidding':
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, p) model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, p)
model.set_onnx_model(p, os.path.abspath(model_path)) if os.path.exists(model_path):
model.set_onnx_model(p, os.path.abspath(model_path))
for bid_obs in bid_obs_buffer: for bid_obs in bid_obs_buffer:
obs_z_buf["bidding"].append(bid_obs['z_batch']) obs_z_buf["bidding"].append(bid_obs['z_batch'])