修复onnx模式训练的BUG
This commit is contained in:
parent
94d64889a7
commit
a069ec2026
|
@ -11,6 +11,8 @@ parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp'
|
||||||
help='Use ADP or WP as reward (default: ADP)')
|
help='Use ADP or WP as reward (default: ADP)')
|
||||||
|
|
||||||
# Training settings
|
# Training settings
|
||||||
|
parser.add_argument('--onnx_sync_interval', default=5, type=int,
|
||||||
|
help='Time interval (in minutes) at which to sync the onnx model')
|
||||||
parser.add_argument('--actor_device_cpu', action='store_true',
|
parser.add_argument('--actor_device_cpu', action='store_true',
|
||||||
help='Use CPU as actor device')
|
help='Use CPU as actor device')
|
||||||
parser.add_argument('--gpu_devices', default='0', type=str,
|
parser.add_argument('--gpu_devices', default='0', type=str,
|
||||||
|
|
|
@ -130,6 +130,25 @@ def train(flags):
|
||||||
frames, stats = 0, {k: 0 for k in stat_keys}
|
frames, stats = 0, {k: 0 for k in stat_keys}
|
||||||
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
||||||
|
|
||||||
|
def sync_onnx_model(frames):
    """Export each position's learner model to ONNX and record the frame count.

    For every position the current learner model is written with
    ``torch.onnx.export`` to ``<savedir>/<xpid>/model_<position>.onnx``, and
    ``onnx_frame.value`` is then set to ``frames``.

    Relies on ``flags``, ``learner_model`` and ``onnx_frame`` from the
    enclosing scope.

    Args:
        frames: Frame counter stored in ``onnx_frame.value`` after each export.
    """
    # NOTE(review): the original indentation was lost in this view; this
    # assumes the ``onnx_frame.value`` update lived inside the
    # ``flags.enable_onnx`` branch, so nothing happens when ONNX is disabled.
    # Confirm against the real file.
    if not flags.enable_onnx:
        return

    # The export device does not depend on the position, so resolve it once.
    if flags.training_device == 'cpu':
        export_device = torch.device('cpu')
    else:
        export_device = torch.device('cuda:' + flags.training_device)

    for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
        # Fetch the model once and reuse it for both the example inputs and
        # the export call.
        model = learner_model.get_model(position)
        model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
        onnx_params = model.get_onnx_params(export_device)
        torch.onnx.export(
            model,
            onnx_params['args'],
            model_path,
            export_params=True,
            opset_version=10,
            do_constant_folding=True,
            input_names=onnx_params['input_names'],
            output_names=onnx_params['output_names'],
            dynamic_axes=onnx_params['dynamic_axes'],
        )
        # Updated after each position's export, matching the original order.
        onnx_frame.value = frames
|
||||||
|
|
||||||
# Load models if any
|
# Load models if any
|
||||||
if flags.load_model and os.path.exists(checkpointpath):
|
if flags.load_model and os.path.exists(checkpointpath):
|
||||||
checkpoint_states = torch.load(
|
checkpoint_states = torch.load(
|
||||||
|
@ -145,6 +164,7 @@ def train(flags):
|
||||||
|
|
||||||
frames = checkpoint_states["frames"]
|
frames = checkpoint_states["frames"]
|
||||||
position_frames = checkpoint_states["position_frames"]
|
position_frames = checkpoint_states["position_frames"]
|
||||||
|
sync_onnx_model(frames)
|
||||||
log.info(f"Resuming preempted job, current stats:\n{stats}")
|
log.info(f"Resuming preempted job, current stats:\n{stats}")
|
||||||
|
|
||||||
# Starting actor processes
|
# Starting actor processes
|
||||||
|
@ -217,29 +237,13 @@ def train(flags):
|
||||||
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
||||||
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
|
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
|
||||||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||||
if flags.enable_onnx and position:
|
|
||||||
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
|
|
||||||
onnx_params = learner_model.get_model(position)\
|
|
||||||
.get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
|
|
||||||
torch.onnx.export(
|
|
||||||
learner_model.get_model(position),
|
|
||||||
onnx_params['args'],
|
|
||||||
model_path,
|
|
||||||
export_params=True,
|
|
||||||
opset_version=10,
|
|
||||||
do_constant_folding=True,
|
|
||||||
input_names=onnx_params['input_names'],
|
|
||||||
output_names=onnx_params['output_names'],
|
|
||||||
dynamic_axes=onnx_params['dynamic_axes']
|
|
||||||
)
|
|
||||||
onnx_frame.value = frames
|
|
||||||
shutil.move(checkpointpath + '.new', checkpointpath)
|
shutil.move(checkpointpath + '.new', checkpointpath)
|
||||||
|
|
||||||
|
|
||||||
fps_log = []
|
fps_log = []
|
||||||
timer = timeit.default_timer
|
timer = timeit.default_timer
|
||||||
try:
|
try:
|
||||||
last_checkpoint_time = timer() - flags.save_interval * 60
|
last_checkpoint_time = timer() - flags.save_interval * 60
|
||||||
|
last_onnx_sync_time = timer()
|
||||||
while frames < flags.total_frames:
|
while frames < flags.total_frames:
|
||||||
start_frames = frames
|
start_frames = frames
|
||||||
position_start_frames = {k: position_frames[k] for k in position_frames}
|
position_start_frames = {k: position_frames[k] for k in position_frames}
|
||||||
|
@ -249,6 +253,11 @@ def train(flags):
|
||||||
if timer() - last_checkpoint_time > flags.save_interval * 60:
|
if timer() - last_checkpoint_time > flags.save_interval * 60:
|
||||||
checkpoint(frames)
|
checkpoint(frames)
|
||||||
last_checkpoint_time = timer()
|
last_checkpoint_time = timer()
|
||||||
|
|
||||||
|
if timer() - last_onnx_sync_time > flags.onnx_sync_interval * 60:
|
||||||
|
sync_onnx_model(frames)
|
||||||
|
last_onnx_sync_time = timer()
|
||||||
|
|
||||||
end_time = timer()
|
end_time = timer()
|
||||||
|
|
||||||
fps = (frames - start_frames) / (end_time - start_time)
|
fps = (frames - start_frames) / (end_time - start_time)
|
||||||
|
|
|
@ -4,4 +4,4 @@ gitdb2
|
||||||
rlcard
|
rlcard
|
||||||
psutil
|
psutil
|
||||||
onnx
|
onnx
|
||||||
onnxruntime-gpu
|
onnxruntime-gpu==1.7
|
||||||
|
|
Loading…
Reference in New Issue