修复BUG

This commit is contained in:
zhiyang7 2021-12-15 10:03:26 +08:00
parent 9f2e8f74f3
commit 6fa590a697
4 changed files with 7 additions and 10 deletions

View File

@@ -19,7 +19,6 @@ import psutil
import shutil import shutil
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']} mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']}
onnx_frame = mp.Value('d', -1)
def compute_loss(logits, targets): def compute_loss(logits, targets):
loss = ((logits.squeeze(-1) - targets)**2).mean() loss = ((logits.squeeze(-1) - targets)**2).mean()
@@ -80,7 +79,6 @@ def train(flags):
Then it will start subprocesses as actors. Then, it will call Then it will start subprocesses as actors. Then, it will call
learning function with multiple threads. learning function with multiple threads.
""" """
global onnx_frame
if not flags.actor_device_cpu or flags.training_device != 'cpu': if not flags.actor_device_cpu or flags.training_device != 'cpu':
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`") raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`")
@@ -116,6 +114,7 @@ def train(flags):
actor_processes = [] actor_processes = []
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
batch_queues = {"landlord": ctx.SimpleQueue(), "landlord_up": ctx.SimpleQueue(), 'landlord_front': ctx.SimpleQueue(), "landlord_down": ctx.SimpleQueue(), "bidding": ctx.SimpleQueue()} batch_queues = {"landlord": ctx.SimpleQueue(), "landlord_up": ctx.SimpleQueue(), 'landlord_front': ctx.SimpleQueue(), "landlord_down": ctx.SimpleQueue(), "bidding": ctx.SimpleQueue()}
onnx_frame = ctx.Value('d', -1)
# Learner model for training # Learner model for training
if flags.old_model: if flags.old_model:
@@ -216,7 +215,6 @@ def train(flags):
threads.append(thread) threads.append(thread)
def checkpoint(frames): def checkpoint(frames):
global onnx_frame
if flags.disable_checkpoint: if flags.disable_checkpoint:
return return
log.info('Saving checkpoint to %s', checkpointpath) log.info('Saving checkpoint to %s', checkpointpath)
@@ -256,7 +254,7 @@ def train(flags):
} }
} }
) )
onnx_frame = frames onnx_frame.value = frames
shutil.move(checkpointpath + '.new', checkpointpath) shutil.move(checkpointpath + '.new', checkpointpath)

View File

@@ -495,7 +495,7 @@ class Model:
self.models['bidding'] = BidModel().to(torch.device(device)) self.models['bidding'] = BidModel().to(torch.device(device))
def set_onnx_model(self, position, model_path): def set_onnx_model(self, position, model_path):
self.onnx_models[position] = get_example(model_path) self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path))
def forward(self, position, z, x, return_value=False, flags=None, debug=False): def forward(self, position, z, x, return_value=False, flags=None, debug=False):
model = self.onnx_models[position] model = self.onnx_models[position]
@@ -503,8 +503,7 @@ class Model:
model = self.models[position] model = self.models[position]
values = model.forward(z, x)['values'] values = model.forward(z, x)['values']
else: else:
sess = onnxruntime.InferenceSession(model) onnx_out = model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
onnx_out = sess.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
values = torch.tensor(onnx_out[0]) values = torch.tensor(onnx_out[0])
if return_value: if return_value:
return dict(values=values) return dict(values=values)

View File

@@ -113,8 +113,8 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
last_onnx_frame = -1 last_onnx_frame = -1
while True: while True:
# print("posi", position) # print("posi", position)
if onnx_frame != last_onnx_frame: if onnx_frame.value != last_onnx_frame:
last_onnx_frame = onnx_frame last_onnx_frame = onnx_frame.value
for p in positions: for p in positions:
if p != 'bidding': if p != 'bidding':
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, p) model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, p)

View File

@@ -49,7 +49,7 @@ class DeepAgent:
x_batch = torch.from_numpy(obs['x_batch']).float() x_batch = torch.from_numpy(obs['x_batch']).float()
if torch.cuda.is_available(): if torch.cuda.is_available():
z_batch, x_batch = z_batch.cuda(), x_batch.cuda() z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
y_pred = self.model.forward(z_batch, x_batch, return_value=True)['values'] y_pred = self.model.forward(z_batch, x_batch)['values']
y_pred = y_pred.detach().cpu().numpy() y_pred = y_pred.detach().cpu().numpy()
best_action_index = np.argmax(y_pred, axis=0)[0] best_action_index = np.argmax(y_pred, axis=0)[0]