进程守护逻辑,添加onnx模型位置参数
This commit is contained in:
parent
eee9bce7dc
commit
66b432f52c
|
@ -37,6 +37,8 @@ parser.add_argument('--savedir', default='douzero_checkpoints',
|
|||
help='Root dir where experiment data will be saved')
|
||||
parser.add_argument('--enable_onnx', action='store_true',
|
||||
help='Use onnx model for train')
|
||||
parser.add_argument('--onnx_model_path', default='douzero_checkpoints',
|
||||
help='Root dir where onnx temp model will be saved')
|
||||
|
||||
# Hyperparameters
|
||||
parser.add_argument('--total_frames', default=100000000000, type=int,
|
||||
|
|
|
@ -104,7 +104,7 @@ def train(flags):
|
|||
# Initialize queues
|
||||
actor_processes = []
|
||||
ctx = mp.get_context('spawn')
|
||||
batch_queues = {"landlord": ctx.SimpleQueue(), "landlord_up": ctx.SimpleQueue(), 'landlord_front': ctx.SimpleQueue(), "landlord_down": ctx.SimpleQueue()}
|
||||
batch_queues = {"landlord": ctx.Queue(flags.unroll_length * 4), "landlord_up": ctx.Queue(flags.unroll_length * 4), 'landlord_front': ctx.Queue(flags.unroll_length * 4), "landlord_down": ctx.Queue(flags.unroll_length * 4)}
|
||||
onnx_frame = ctx.Value('d', -1)
|
||||
|
||||
# Learner model for training
|
||||
|
@ -131,6 +131,9 @@ def train(flags):
|
|||
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
||||
|
||||
def sync_onnx_model(frames):
|
||||
p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
|
||||
if not os.path.exists(p_path):
|
||||
os.makedirs(p_path)
|
||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
||||
if flags.enable_onnx:
|
||||
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
|
||||
|
@ -179,7 +182,11 @@ def train(flags):
|
|||
args=(i, device, batch_queues, models[device], flags, onnx_frame))
|
||||
actor.daemon = True
|
||||
actor.start()
|
||||
actor_processes.append(actor)
|
||||
actor_processes.append({
|
||||
'device': device,
|
||||
'i': i,
|
||||
'actor': actor
|
||||
})
|
||||
|
||||
parent = psutil.Process()
|
||||
parent.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)
|
||||
|
@ -279,6 +286,14 @@ def train(flags):
|
|||
position_fps['landlord_front'],
|
||||
position_fps['landlord_down'],
|
||||
pprint.pformat(stats))
|
||||
# Watchdog: replace every actor process that has exited since the last pass.
for proc in actor_processes:
    if proc['actor'].is_alive():
        continue
    # Use the device recorded for THIS actor. The bare name `device` is only
    # whatever the earlier spawn loop left behind, so it would hand the
    # restarted actor the model of the wrong device.
    # Create the process from `ctx` (the 'spawn' context) because the queues
    # and the shared onnx_frame value passed below were created from that
    # same context; a process from the default context may not be compatible.
    actor = ctx.Process(
        target=act,
        args=(proc['i'], proc['device'], batch_queues,
              models[proc['device']], flags, onnx_frame))
    actor.daemon = True
    actor.start()
    # Track the replacement so the next pass monitors it instead of the dead one.
    proc['actor'] = actor
|
||||
|
||||
except KeyboardInterrupt:
|
||||
checkpoint(frames)
|
||||
|
|
Loading…
Reference in New Issue