进程守护逻辑,添加onnx模型位置参数

This commit is contained in:
zhiyang7 2021-12-21 12:49:53 +08:00
parent eee9bce7dc
commit 66b432f52c
2 changed files with 19 additions and 2 deletions

View File

@ -37,6 +37,8 @@ parser.add_argument('--savedir', default='douzero_checkpoints',
help='Root dir where experiment data will be saved')
# Boolean switch (off unless passed): use the ONNX model for training.
parser.add_argument(
    '--enable_onnx',
    action='store_true',
    help='Use onnx model for train',
)
# Root directory under which the temporary ONNX model files are written.
parser.add_argument(
    '--onnx_model_path',
    default='douzero_checkpoints',
    help='Root dir where onnx temp model will be saved',
)
# Hyperparameters
parser.add_argument('--total_frames', default=100000000000, type=int,

View File

@ -104,7 +104,7 @@ def train(flags):
# Initialize queues
actor_processes = []
ctx = mp.get_context('spawn')
batch_queues = {"landlord": ctx.SimpleQueue(), "landlord_up": ctx.SimpleQueue(), 'landlord_front': ctx.SimpleQueue(), "landlord_down": ctx.SimpleQueue()}
batch_queues = {"landlord": ctx.Queue(flags.unroll_length * 4), "landlord_up": ctx.Queue(flags.unroll_length * 4), 'landlord_front': ctx.Queue(flags.unroll_length * 4), "landlord_down": ctx.Queue(flags.unroll_length * 4)}
onnx_frame = ctx.Value('d', -1)
# Learner model for training
@ -131,6 +131,9 @@ def train(flags):
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
def sync_onnx_model(frames):
# Per-experiment directory that receives the exported model_<position>.onnx files:
# <onnx_model_path>/<xpid>.
p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
# exist_ok=True replaces the separate exists() check, so a directory that is
# created between the check and the call no longer raises FileExistsError.
os.makedirs(p_path, exist_ok=True)
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
if flags.enable_onnx:
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
@ -179,7 +182,11 @@ def train(flags):
args=(i, device, batch_queues, models[device], flags, onnx_frame))
actor.daemon = True
actor.start()
actor_processes.append(actor)
actor_processes.append({
'device': device,
'i': i,
'actor': actor
})
parent = psutil.Process()
parent.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)
@ -279,6 +286,14 @@ def train(flags):
position_fps['landlord_front'],
position_fps['landlord_down'],
pprint.pformat(stats))
# Watchdog pass: replace every actor process that is no longer alive.
for proc in actor_processes:
    if proc['actor'].is_alive():
        continue
    # Rebuild the actor from the values recorded for this slot. The model must
    # be looked up with the slot's own device: a bare `device` here would be
    # whatever value the startup loop left behind, handing the restarted actor
    # the model of a different device when several devices are in use.
    # NOTE(review): queues come from the 'spawn' context (ctx); confirm that
    # mp.Process uses the same start method as the initial actor creation.
    actor = mp.Process(
        target=act,
        args=(proc['i'], proc['device'], batch_queues,
              models[proc['device']], flags, onnx_frame))
    # Daemon processes are terminated automatically when the parent exits.
    actor.daemon = True
    actor.start()
    # Track the replacement so the next watchdog pass monitors it.
    proc['actor'] = actor
except KeyboardInterrupt:
checkpoint(frames)