修复训练vanilla模型的BUG
This commit is contained in:
parent
42f28b1aa9
commit
a4d2a93afc
|
@ -104,7 +104,7 @@ def train(flags):
|
|||
# Initialize queues
|
||||
actor_processes = []
|
||||
ctx = mp.get_context('spawn')
|
||||
batch_queues = {"landlord": ctx.Queue(flags.unroll_length * 4), "landlord_up": ctx.Queue(flags.unroll_length * 4), 'landlord_front': ctx.Queue(flags.unroll_length * 4), "landlord_down": ctx.Queue(flags.unroll_length * 4)}
|
||||
batch_queues = {"landlord": ctx.SimpleQueue(), "landlord_up": ctx.SimpleQueue(), 'landlord_front': ctx.SimpleQueue(), "landlord_down": ctx.SimpleQueue()}
|
||||
onnx_frame = ctx.Value('d', -1)
|
||||
|
||||
# Learner model for training
|
||||
|
@ -268,7 +268,7 @@ def train(flags):
|
|||
|
||||
fps = (frames - start_frames) / (end_time - start_time)
|
||||
fps_log.append(fps)
|
||||
if len(fps_log) > 24:
|
||||
if len(fps_log) > 240:
|
||||
fps_log = fps_log[1:]
|
||||
fps_avg = np.mean(fps_log)
|
||||
|
||||
|
|
|
@ -330,22 +330,34 @@ class OldModel:
|
|||
"""
|
||||
def __init__(self, device=0, flags=None):
|
||||
self.models = {}
|
||||
self.onnx_models = {}
|
||||
self.flags = flags
|
||||
if not device == "cpu":
|
||||
device = 'cuda:' + str(device)
|
||||
self.device = torch.device(device)
|
||||
self.models['landlord'] = LandlordLstmModel().to(self.device)
|
||||
self.models['landlord_up'] = FarmerLstmModel().to(self.device)
|
||||
self.models['landlord_front'] = FarmerLstmModel().to(self.device)
|
||||
self.models['landlord_down'] = FarmerLstmModel().to(self.device)
|
||||
self.onnx_models = {
|
||||
'landlord': None,
|
||||
'landlord_up': None,
|
||||
'landlord_front': None,
|
||||
'landlord_down': None
|
||||
}
|
||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||
if flags is not None and flags.enable_onnx:
|
||||
for position in positions:
|
||||
self.models[position] = None
|
||||
else:
|
||||
for position in positions[1:]:
|
||||
self.models[position] = FarmerLstmModel().to(self.device)
|
||||
self.models['landlord'] = LandlordLstmModel().to(self.device)
|
||||
self.onnx_models = {
|
||||
'landlord': None,
|
||||
'landlord_up': None,
|
||||
'landlord_front': None,
|
||||
'landlord_down': None
|
||||
}
|
||||
|
||||
def set_onnx_model(self, position, model_path):
|
||||
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path))
|
||||
def set_onnx_model(self, device='cpu'):
|
||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||
for position in positions:
|
||||
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid, position))
|
||||
if device == 'cpu':
|
||||
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
|
||||
else:
|
||||
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
|
||||
|
||||
def get_onnx_params(self, position):
|
||||
self.models[position].get_onnx_params(self.device)
|
||||
|
@ -356,16 +368,22 @@ class OldModel:
|
|||
model = self.models[position]
|
||||
values = model.forward(z, x)['values']
|
||||
else:
|
||||
onnx_out = model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
|
||||
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
|
||||
values = torch.tensor(onnx_out[0])
|
||||
if return_value:
|
||||
return dict(values=values)
|
||||
return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
|
||||
else:
|
||||
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
|
||||
action = torch.randint(values.shape[0], (1,))[0]
|
||||
if flags.enable_onnx:
|
||||
action = np.random.randint(0, values.shape[0], (1,))[0]
|
||||
else:
|
||||
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
|
||||
else:
|
||||
action = torch.argmax(values,dim=0)[0]
|
||||
return dict(action=action)
|
||||
if flags.enable_onnx:
|
||||
action = np.argmax(values, axis=0)[0].cpu().detach().numpy()
|
||||
else:
|
||||
action = torch.argmax(values,dim=0)[0]
|
||||
return dict(action = action)
|
||||
|
||||
def share_memory(self):
|
||||
if self.models['landlord'] is not None:
|
||||
|
|
Loading…
Reference in New Issue