修复bug

This commit is contained in:
zhiyang7 2022-01-04 21:12:20 +08:00
parent 8dd95e3e59
commit 0f8cd23c20
2 changed files with 26 additions and 8 deletions

View File

@ -120,6 +120,10 @@ def train(flags):
]
frames, stats = 0, {k: 0 for k in stat_keys}
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
if flags.unified_model:
lock = threading.Lock()
position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock}
else:
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
def sync_onnx_model(frames):
@ -153,6 +157,12 @@ def train(flags):
checkpoint_states = torch.load(
checkpointpath, map_location=("cuda:"+str(flags.training_device) if flags.training_device != "cpu" else "cpu")
)
if flags.unified_model:
learner_model.get_model('uni').load_state_dict(checkpoint_states["model_state_dict"]['uni'])
optimizers['uni'].load_state_dict(checkpoint_states["optimizer_state_dict"]['uni'])
if not flags.enable_onnx:
actor_model.get_model('uni').load_state_dict(checkpoint_states["model_state_dict"]['uni'])
else:
for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: # ['landlord', 'landlord_up', 'landlord_down']
learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
@ -210,7 +220,7 @@ def train(flags):
while frames < flags.total_frames:
batch = get_batch(batch_queues, position, flags, local_lock)
_stats = learn(position, actor_model, learner_model.get_model(position), batch,
optimizers[position], flags, position_lock)
optimizers['uni'], flags, position_lock)
with lock:
for k in _stats:
stats[k] = _stats[k]
@ -248,7 +258,7 @@ def train(flags):
}, checkpointpath + '.new')
# Save the weights for evaluation purpose
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
for position in ['uni'] if flags.unified_model else ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
model_weights_dir = os.path.expandvars(os.path.expanduser(
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)

View File

@ -104,6 +104,14 @@ def create_optimizers(flags, learner_model):
"""
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
optimizers = {}
if flags.unified_model:
position = 'uni'
optimizer = RAdam(
learner_model.parameters(position),
lr=flags.learning_rate,
eps=flags.epsilon)
optimizers[position] = optimizer
else:
for position in positions:
optimizer = RAdam(
learner_model.parameters(position),