修复onnx模式训练的BUG

This commit is contained in:
ZaneYork 2021-12-19 18:49:08 +08:00
parent 94d64889a7
commit a069ec2026
3 changed files with 29 additions and 18 deletions

View File

@ -11,6 +11,8 @@ parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp'
help='Use ADP or WP as reward (default: ADP)') help='Use ADP or WP as reward (default: ADP)')
# Training settings # Training settings
parser.add_argument('--onnx_sync_interval', default=5, type=int,
help='Time interval (in minutes) at which to sync the onnx model')
parser.add_argument('--actor_device_cpu', action='store_true', parser.add_argument('--actor_device_cpu', action='store_true',
help='Use CPU as actor device') help='Use CPU as actor device')
parser.add_argument('--gpu_devices', default='0', type=str, parser.add_argument('--gpu_devices', default='0', type=str,

View File

@ -130,6 +130,25 @@ def train(flags):
frames, stats = 0, {k: 0 for k in stat_keys} frames, stats = 0, {k: 0 for k in stat_keys}
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0} position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
def sync_onnx_model(frames):
    """Export each position's learner model to ONNX and record the frame count.

    For every player position the current learner model is written to
    ``<savedir>/<xpid>/model_<position>.onnx``. After each export the shared
    ``onnx_frame`` value is set to ``frames``.
    NOTE(review): presumably actor processes read ``onnx_frame`` to decide
    when to reload the exported models - confirm against the consumer.

    Does nothing when ``flags.enable_onnx`` is false.

    Args:
        frames: Training frame count the exported models correspond to.
    """
    # The flag does not change during the call, so test it once up front
    # instead of once per position.
    if not flags.enable_onnx:
        return

    # The device used to build the example export inputs is the same for all
    # positions; construct it once.
    if flags.training_device == 'cpu':
        export_device = torch.device('cpu')
    else:
        export_device = torch.device('cuda:' + flags.training_device)

    for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
        model = learner_model.get_model(position)
        model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
        onnx_params = model.get_onnx_params(export_device)
        torch.onnx.export(
            model,
            onnx_params['args'],
            model_path,
            export_params=True,
            opset_version=10,
            do_constant_folding=True,
            input_names=onnx_params['input_names'],
            output_names=onnx_params['output_names'],
            dynamic_axes=onnx_params['dynamic_axes']
        )
        # Record which frame count the file just written belongs to.
        onnx_frame.value = frames
# Load models if any # Load models if any
if flags.load_model and os.path.exists(checkpointpath): if flags.load_model and os.path.exists(checkpointpath):
checkpoint_states = torch.load( checkpoint_states = torch.load(
@ -145,6 +164,7 @@ def train(flags):
frames = checkpoint_states["frames"] frames = checkpoint_states["frames"]
position_frames = checkpoint_states["position_frames"] position_frames = checkpoint_states["position_frames"]
sync_onnx_model(frames)
log.info(f"Resuming preempted job, current stats:\n{stats}") log.info(f"Resuming preempted job, current stats:\n{stats}")
# Starting actor processes # Starting actor processes
@ -217,29 +237,13 @@ def train(flags):
model_weights_dir = os.path.expandvars(os.path.expanduser( model_weights_dir = os.path.expandvars(os.path.expanduser(
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt'))) '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir) torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
if flags.enable_onnx and position:
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
onnx_params = learner_model.get_model(position)\
.get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
torch.onnx.export(
learner_model.get_model(position),
onnx_params['args'],
model_path,
export_params=True,
opset_version=10,
do_constant_folding=True,
input_names=onnx_params['input_names'],
output_names=onnx_params['output_names'],
dynamic_axes=onnx_params['dynamic_axes']
)
onnx_frame.value = frames
shutil.move(checkpointpath + '.new', checkpointpath) shutil.move(checkpointpath + '.new', checkpointpath)
fps_log = [] fps_log = []
timer = timeit.default_timer timer = timeit.default_timer
try: try:
last_checkpoint_time = timer() - flags.save_interval * 60 last_checkpoint_time = timer() - flags.save_interval * 60
last_onnx_sync_time = timer()
while frames < flags.total_frames: while frames < flags.total_frames:
start_frames = frames start_frames = frames
position_start_frames = {k: position_frames[k] for k in position_frames} position_start_frames = {k: position_frames[k] for k in position_frames}
@ -249,6 +253,11 @@ def train(flags):
if timer() - last_checkpoint_time > flags.save_interval * 60: if timer() - last_checkpoint_time > flags.save_interval * 60:
checkpoint(frames) checkpoint(frames)
last_checkpoint_time = timer() last_checkpoint_time = timer()
if timer() - last_onnx_sync_time > flags.onnx_sync_interval * 60:
sync_onnx_model(frames)
last_onnx_sync_time = timer()
end_time = timer() end_time = timer()
fps = (frames - start_frames) / (end_time - start_time) fps = (frames - start_frames) / (end_time - start_time)

View File

@ -4,4 +4,4 @@ gitdb2
rlcard rlcard
psutil psutil
onnx onnx
onnxruntime-gpu onnxruntime-gpu==1.7