diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index b791a38..eddbd26 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -166,12 +166,12 @@ def train(flags): for j in range(flags.num_actors_thread): for i in range(num_actors): infer_queues.append({ - 'input': ctx.Queue(), 'output': ctx.Queue() + 'input': ctx.Queue(maxsize=100), 'output': ctx.Queue(maxsize=100) }) infer_processes = [] for device in flags.infer_devices.split(','): - for i in range(flags.num_infer): + for i in range(flags.num_infer if device != 'cpu' else 1): infer = mp.Process( target=infer_logic, args=(i, device, infer_queues, actor_model, flags, onnx_frame))