Fix a bug in ONNX-mode training
parent 94d64889a7
commit a069ec2026
@@ -11,6 +11,8 @@ parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp'
                     help='Use ADP or WP as reward (default: ADP)')

 # Training settings
+parser.add_argument('--onnx_sync_interval', default=5, type=int,
+                    help='Time interval (in minutes) at which to sync the onnx model')
 parser.add_argument('--actor_device_cpu', action='store_true',
                     help='Use CPU as actor device')
 parser.add_argument('--gpu_devices', default='0', type=str,

@@ -130,6 +130,25 @@ def train(flags):
     frames, stats = 0, {k: 0 for k in stat_keys}
     position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}

+    def sync_onnx_model(frames):
+        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
+            if flags.enable_onnx and position:
+                model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
+                onnx_params = learner_model.get_model(position)\
+                    .get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
+                torch.onnx.export(
+                    learner_model.get_model(position),
+                    onnx_params['args'],
+                    model_path,
+                    export_params=True,
+                    opset_version=10,
+                    do_constant_folding=True,
+                    input_names=onnx_params['input_names'],
+                    output_names=onnx_params['output_names'],
+                    dynamic_axes=onnx_params['dynamic_axes']
+                )
+        onnx_frame.value = frames
+
     # Load models if any
     if flags.load_model and os.path.exists(checkpointpath):
         checkpoint_states = torch.load(
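The new sync_onnx_model() helper assumes each position model exposes a get_onnx_params() method that bundles everything torch.onnx.export needs. That method is not shown in this commit; the sketch below only illustrates the expected return contract, with hypothetical input names and tensor shapes.

    import torch

    # Hypothetical sketch (not part of this commit) of the get_onnx_params() method
    # that sync_onnx_model() calls on each position model. The dict keys mirror the
    # ones consumed by torch.onnx.export above; names and shapes are placeholders.
    def get_onnx_params(self, device):
        z = torch.zeros((1, 40, 108), device=device)   # dummy state input, illustrative shape
        x = torch.zeros((1, 80), device=device)        # dummy feature input, illustrative shape
        return {
            'args': (z, x),                             # sample inputs traced during export
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {                           # keep the batch dimension dynamic
                'z_batch': {0: 'batch_size'},
                'x_batch': {0: 'batch_size'},
                'values': {0: 'batch_size'},
            },
        }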
@@ -145,6 +164,7 @@ def train(flags):

         frames = checkpoint_states["frames"]
         position_frames = checkpoint_states["position_frames"]
+        sync_onnx_model(frames)
         log.info(f"Resuming preempted job, current stats:\n{stats}")

     # Starting actor processes
@@ -217,29 +237,13 @@ def train(flags):
             model_weights_dir = os.path.expandvars(os.path.expanduser(
                 '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
             torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
-            if flags.enable_onnx and position:
-                model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
-                onnx_params = learner_model.get_model(position)\
-                    .get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
-                torch.onnx.export(
-                    learner_model.get_model(position),
-                    onnx_params['args'],
-                    model_path,
-                    export_params=True,
-                    opset_version=10,
-                    do_constant_folding=True,
-                    input_names=onnx_params['input_names'],
-                    output_names=onnx_params['output_names'],
-                    dynamic_axes=onnx_params['dynamic_axes']
-                )
-                onnx_frame.value = frames
         shutil.move(checkpointpath + '.new', checkpointpath)

     fps_log = []
     timer = timeit.default_timer
     try:
         last_checkpoint_time = timer() - flags.save_interval * 60
+        last_onnx_sync_time = timer()
         while frames < flags.total_frames:
             start_frames = frames
             position_start_frames = {k: position_frames[k] for k in position_frames}
@@ -249,6 +253,11 @@ def train(flags):
             if timer() - last_checkpoint_time > flags.save_interval * 60:
                 checkpoint(frames)
                 last_checkpoint_time = timer()
+
+            if timer() - last_onnx_sync_time > flags.onnx_sync_interval * 60:
+                sync_onnx_model(frames)
+                last_onnx_sync_time = timer()
+
             end_time = timer()

             fps = (frames - start_frames) / (end_time - start_time)
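The actor side is not touched by this diff, but the periodic sync above only helps if actors reload the exported file when onnx_frame.value changes. The snippet below is a minimal, assumed sketch of that consumer logic using onnxruntime (the dependency pinned in requirements.txt); the function and variable names are illustrative, not the project's actual code.

    import onnxruntime as ort

    def maybe_reload_session(session, last_synced_frame, onnx_frame, model_path):
        # Rebuild the inference session only when the learner reports a newer export.
        if onnx_frame.value != last_synced_frame:
            session = ort.InferenceSession(
                model_path,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
            last_synced_frame = onnx_frame.value
        return session, last_synced_frame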
@@ -4,4 +4,4 @@ gitdb2
 rlcard
 psutil
 onnx
-onnxruntime-gpu
+onnxruntime-gpu==1.7
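Since onnx stays in the requirements alongside the pinned runtime, an exported file can be sanity-checked independently of training; this snippet is illustrative and not part of the commit.

    import onnx

    model = onnx.load('model_landlord.onnx')   # path is illustrative
    onnx.checker.check_model(model)            # raises if the exported graph is malformed
    print(onnx.helper.printable_graph(model.graph))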