Fix bug
commit 0f8cd23c20 (parent 8dd95e3e59)
@@ -120,6 +120,10 @@ def train(flags):
     ]
     frames, stats = 0, {k: 0 for k in stat_keys}
     position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
-    position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
+    if flags.unified_model:
+        lock = threading.Lock()
+        position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock}
+    else:
+        position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}

     def sync_onnx_model(frames):
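Note on the hunk above: when flags.unified_model is set, all four entries of position_locks alias one threading.Lock, so learner threads for different seats cannot step the single shared network at the same time; in the per-position case each seat keeps its own lock as before. The snippet below is an illustration only (not from the repository) of what that aliasing means:

# Illustration only (assumption): all four entries alias a single Lock when the
# model is unified, so optimizer steps for different seats are serialized.
import threading

shared = threading.Lock()
position_locks = {p: shared for p in
                  ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']}

assert position_locks['landlord'] is position_locks['landlord_down']

with position_locks['landlord']:
    # While any seat's lock is held, acquiring the lock of another seat fails,
    # because every entry is literally the same object.
    assert not position_locks['landlord_up'].acquire(blocking=False)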
@@ -153,6 +157,12 @@ def train(flags):
         checkpoint_states = torch.load(
             checkpointpath, map_location=("cuda:"+str(flags.training_device) if flags.training_device != "cpu" else "cpu")
         )
-        for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: # ['landlord', 'landlord_up', 'landlord_down']
-            learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
-            optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
+        if flags.unified_model:
+            learner_model.get_model('uni').load_state_dict(checkpoint_states["model_state_dict"]['uni'])
+            optimizers['uni'].load_state_dict(checkpoint_states["optimizer_state_dict"]['uni'])
+            if not flags.enable_onnx:
+                actor_model.get_model('uni').load_state_dict(checkpoint_states["model_state_dict"]['uni'])
+        else:
+            for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: # ['landlord', 'landlord_up', 'landlord_down']
+                learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
+                optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
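The resume branch above assumes the checkpoint dictionaries carry a single 'uni' entry in unified mode, and it skips copying weights into actor_model when ONNX is enabled (actors presumably run an exported ONNX model in that case). A hedged sketch of the layout this branch expects, with placeholder values:

# Sketch (assumption): checkpoint layout consumed by the unified-model branch.
# Only the keys visible in the diff are shown; real checkpoints may carry more.
checkpoint_states = {
    "model_state_dict": {"uni": {}},       # state_dict of the single shared network
    "optimizer_state_dict": {"uni": {}},   # matching optimizer state
}
uni_weights = checkpoint_states["model_state_dict"]["uni"]   # what load_state_dict would receive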
@@ -210,7 +220,7 @@ def train(flags):
         while frames < flags.total_frames:
             batch = get_batch(batch_queues, position, flags, local_lock)
             _stats = learn(position, actor_model, learner_model.get_model(position), batch,
-                           optimizers[position], flags, position_lock)
+                           optimizers['uni'], flags, position_lock)
             with lock:
                 for k in _stats:
                     stats[k] = _stats[k]
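For context on the hunk above: the optimizer is now looked up under the 'uni' key created by create_optimizers (last hunk below), position_lock guards the parameter update inside learn(), and the separate lock only protects the shared stats dict. A minimal sketch of that two-lock pattern, with stand-in names (train_step, publish_stats) that are not from the repository:

# Minimal sketch (assumption) of the two-lock pattern used by the learner threads:
# one lock serializes the network update, another protects the shared stats dict.
import threading

stats = {}
stats_lock = threading.Lock()

def train_step(position_lock):
    # stand-in for learn(): the real update happens while position_lock is held
    with position_lock:
        return {'loss': 0.0}

def publish_stats(_stats, lock=stats_lock):
    # writers to the shared stats dict are serialized by a separate lock
    with lock:
        for k in _stats:
            stats[k] = _stats[k]

publish_stats(train_step(threading.Lock()))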
@@ -248,7 +258,7 @@ def train(flags):
         }, checkpointpath + '.new')

         # Save the weights for evaluation purpose
-        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
+        for position in ['uni'] if flags.unified_model else ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
             model_weights_dir = os.path.expandvars(os.path.expanduser(
                 '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
             torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
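With the change above, unified mode writes one evaluation checkpoint per save instead of four. A hedged example of the resulting filename, using placeholder savedir/xpid values:

# Example only: the filename pattern produced by the loop above.
# 'checkpoints' and 'guandan' are placeholder savedir/xpid values, not from the commit.
import os

flags_savedir, flags_xpid, frames = 'checkpoints', 'guandan', 3200
for position in ['uni']:   # unified mode saves one file instead of four
    path = os.path.expandvars(os.path.expanduser(
        '%s/%s/%s' % (flags_savedir, flags_xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
    print(path)   # -> checkpoints/guandan/general_uni_3200.ckpt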
@@ -104,6 +104,14 @@ def create_optimizers(flags, learner_model):
     """
     positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
     optimizers = {}
-    for position in positions:
-        optimizer = RAdam(
-            learner_model.parameters(position),
+    if flags.unified_model:
+        position = 'uni'
+        optimizer = RAdam(
+            learner_model.parameters(position),
+            lr=flags.learning_rate,
+            eps=flags.epsilon)
+        optimizers[position] = optimizer
+    else:
+        for position in positions:
+            optimizer = RAdam(
+                learner_model.parameters(position),
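flags.unified_model and flags.enable_onnx drive every branch in this commit but are declared elsewhere in the repository. A hedged sketch of how such flags might be declared with argparse; the flag names come from the diff, while defaults and help text are guesses:

# Sketch (assumption): how the flags used above might be declared.
# Flag names come from the diff; defaults and help text are illustrative only.
import argparse

parser = argparse.ArgumentParser(description='DMC training flags (sketch)')
parser.add_argument('--unified_model', action='store_true',
                    help='train one shared network for all four seats')
parser.add_argument('--enable_onnx', action='store_true',
                    help='actors run an exported ONNX copy of the model')
parser.add_argument('--learning_rate', default=0.0001, type=float)
parser.add_argument('--epsilon', default=1e-5, type=float)

flags = parser.parse_args([])   # empty argv: just take the defaults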