diff --git a/cuda_train.cmd b/cuda_train.cmd
index 7d471d5..dddaab2 100644
--- a/cuda_train.cmd
+++ b/cuda_train.cmd
@@ -1 +1 @@
-python train.py --load_model --batch_size 4 --learning_rate 0.0003 --enable_onnx
\ No newline at end of file
+python train.py --load_model --batch_size 4 --learning_rate 0.0003 --enable_onnx --num_actors 1
\ No newline at end of file
diff --git a/douzero/dmc/arguments.py b/douzero/dmc/arguments.py
index 68244c3..82ba147 100644
--- a/douzero/dmc/arguments.py
+++ b/douzero/dmc/arguments.py
@@ -11,7 +11,7 @@ parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp'
                     help='Use ADP or WP as reward (default: ADP)')
 
 # Training settings
-parser.add_argument('--onnx_sync_interval', default=30, type=int,
+parser.add_argument('--onnx_sync_interval', default=120, type=int,
                     help='Time interval (in seconds) at which to sync the onnx model')
 parser.add_argument('--actor_device_cpu', action='store_true',
                     help='Use CPU as actor device')
diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py
index 242aecf..a7d3c21 100644
--- a/douzero/dmc/dmc.py
+++ b/douzero/dmc/dmc.py
@@ -87,7 +87,7 @@ def train(flags):
     if flags.actor_device_cpu:
         device_iterator = ['cpu']
     else:
-        device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
+        device_iterator = [0, 'cpu'] #range(flags.num_actor_devices) #[0, 'cpu']
         assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
 
     # Initialize actor models
@@ -282,6 +282,7 @@ def train(flags):
                      pprint.pformat(stats))
 
     except KeyboardInterrupt:
+        checkpoint(frames)
         return
     else:
         for thread in threads:
diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py
index 6efde17..3bead7c 100644
--- a/douzero/dmc/models.py
+++ b/douzero/dmc/models.py
@@ -315,18 +315,19 @@ class Model:
             'landlord_down': None
         }
 
-    def set_onnx_model(self):
+    def set_onnx_model(self, device='cpu'):
         positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
         for position in positions:
             model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.savedir, self.flags.xpid, position))
-            self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+            if device == 'cpu':
+                self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
+            else:
+                self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
 
     def get_onnx_params(self, position):
         self.models[position].get_onnx_params(self.device)
 
     def forward(self, position, z, x, return_value=False, flags=None, debug=False):
-        if self.flags.enable_onnx and len(self.onnx_models) == 0:
-            self.set_onnx_model()
         model = self.onnx_models[position]
         if model is None:
             model = self.models[position]
diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py
index 11a372a..5bcf662 100644
--- a/douzero/dmc/utils.py
+++ b/douzero/dmc/utils.py
@@ -112,8 +112,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
             # print("posi", position)
            if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
                 last_onnx_frame = onnx_frame.value
-                for p in positions:
-                    model.set_onnx_model()
+                model.set_onnx_model(device)
 
             while True:
                 if len(obs['legal_actions']) > 1:
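
Note on the `set_onnx_model` change above: the ONNX Runtime session is now created with a provider list chosen from the actor's device, instead of always requesting `['CUDAExecutionProvider', 'CPUExecutionProvider']`. A minimal standalone sketch of that pattern, assuming the same onnxruntime API used in the patch (the `make_session` helper and the example path are illustrative, not part of this change):

    import onnxruntime

    def make_session(model_path, device='cpu'):
        # Hypothetical helper: pick the execution provider from the actor's device.
        # Listing only CPUExecutionProvider keeps CPU actors from touching CUDA;
        # GPU actors request CUDA explicitly.
        if device == 'cpu':
            providers = ['CPUExecutionProvider']
        else:
            providers = ['CUDAExecutionProvider']
        return onnxruntime.InferenceSession(model_path, providers=providers)

    # Illustrative usage (path follows the savedir/xpid layout used in models.py):
    # session = make_session('douzero_checkpoints/douzero/model_landlord.onnx', device=0)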