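修复bug (Fix bug: wire the unified 'uni' model into the position locks, checkpoint loading, optimizer creation, and evaluation-weight saving)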
This commit is contained in:
parent
8dd95e3e59
commit
0f8cd23c20
|
@ -120,6 +120,10 @@ def train(flags):
|
||||||
]
|
]
|
||||||
frames, stats = 0, {k: 0 for k in stat_keys}
|
frames, stats = 0, {k: 0 for k in stat_keys}
|
||||||
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
||||||
|
if flags.unified_model:
|
||||||
|
lock = threading.Lock()
|
||||||
|
position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock}
|
||||||
|
else:
|
||||||
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||||
|
|
||||||
def sync_onnx_model(frames):
|
def sync_onnx_model(frames):
|
||||||
|
@ -153,6 +157,12 @@ def train(flags):
|
||||||
checkpoint_states = torch.load(
|
checkpoint_states = torch.load(
|
||||||
checkpointpath, map_location=("cuda:"+str(flags.training_device) if flags.training_device != "cpu" else "cpu")
|
checkpointpath, map_location=("cuda:"+str(flags.training_device) if flags.training_device != "cpu" else "cpu")
|
||||||
)
|
)
|
||||||
|
if flags.unified_model:
|
||||||
|
learner_model.get_model('uni').load_state_dict(checkpoint_states["model_state_dict"]['uni'])
|
||||||
|
optimizers['uni'].load_state_dict(checkpoint_states["optimizer_state_dict"]['uni'])
|
||||||
|
if not flags.enable_onnx:
|
||||||
|
actor_model.get_model('uni').load_state_dict(checkpoint_states["model_state_dict"]['uni'])
|
||||||
|
else:
|
||||||
for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: # ['landlord', 'landlord_up', 'landlord_down']
|
for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: # ['landlord', 'landlord_up', 'landlord_down']
|
||||||
learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
|
learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
|
||||||
optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
|
optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
|
||||||
|
@ -210,7 +220,7 @@ def train(flags):
|
||||||
while frames < flags.total_frames:
|
while frames < flags.total_frames:
|
||||||
batch = get_batch(batch_queues, position, flags, local_lock)
|
batch = get_batch(batch_queues, position, flags, local_lock)
|
||||||
_stats = learn(position, actor_model, learner_model.get_model(position), batch,
|
_stats = learn(position, actor_model, learner_model.get_model(position), batch,
|
||||||
optimizers[position], flags, position_lock)
|
optimizers['uni'], flags, position_lock)
|
||||||
with lock:
|
with lock:
|
||||||
for k in _stats:
|
for k in _stats:
|
||||||
stats[k] = _stats[k]
|
stats[k] = _stats[k]
|
||||||
|
@ -248,7 +258,7 @@ def train(flags):
|
||||||
}, checkpointpath + '.new')
|
}, checkpointpath + '.new')
|
||||||
|
|
||||||
# Save the weights for evaluation purpose
|
# Save the weights for evaluation purpose
|
||||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
for position in ['uni'] if flags.unified_model else ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
||||||
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
||||||
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
|
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
|
||||||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||||
|
|
|
@ -104,6 +104,14 @@ def create_optimizers(flags, learner_model):
|
||||||
"""
|
"""
|
||||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
optimizers = {}
|
optimizers = {}
|
||||||
|
if flags.unified_model:
|
||||||
|
position = 'uni'
|
||||||
|
optimizer = RAdam(
|
||||||
|
learner_model.parameters(position),
|
||||||
|
lr=flags.learning_rate,
|
||||||
|
eps=flags.epsilon)
|
||||||
|
optimizers[position] = optimizer
|
||||||
|
else:
|
||||||
for position in positions:
|
for position in positions:
|
||||||
optimizer = RAdam(
|
optimizer = RAdam(
|
||||||
learner_model.parameters(position),
|
learner_model.parameters(position),
|
||||||
|
|
Loading…
Reference in New Issue