修复CPU与GPU不能同时训练的BUG,退出时自动存档

This commit is contained in:
ZaneYork 2021-12-19 20:00:42 +08:00
parent 5b0fee04a8
commit 08cf434500
5 changed files with 10 additions and 9 deletions

View File

@ -1 +1 @@
python train.py --load_model --batch_size 4 --learning_rate 0.0003 --enable_onnx
python train.py --load_model --batch_size 4 --learning_rate 0.0003 --enable_onnx --num_actors 1

View File

@ -11,7 +11,7 @@ parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp'
help='Use ADP or WP as reward (default: ADP)')
# Training settings
parser.add_argument('--onnx_sync_interval', default=30, type=int,
parser.add_argument('--onnx_sync_interval', default=120, type=int,
help='Time interval (in seconds) at which to sync the onnx model')
parser.add_argument('--actor_device_cpu', action='store_true',
help='Use CPU as actor device')

View File

@ -87,7 +87,7 @@ def train(flags):
if flags.actor_device_cpu:
device_iterator = ['cpu']
else:
device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
device_iterator = [0, 'cpu'] #range(flags.num_actor_devices) #[0, 'cpu']
assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
# Initialize actor models
@ -282,6 +282,7 @@ def train(flags):
pprint.pformat(stats))
except KeyboardInterrupt:
checkpoint(frames)
return
else:
for thread in threads:

View File

@ -315,18 +315,19 @@ class Model:
'landlord_down': None
}
def set_onnx_model(self):
def set_onnx_model(self, device='cpu'):
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
for position in positions:
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.savedir, self.flags.xpid, position))
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
if device == 'cpu':
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
else:
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
def get_onnx_params(self, position):
self.models[position].get_onnx_params(self.device)
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
if self.flags.enable_onnx and len(self.onnx_models) == 0:
self.set_onnx_model()
model = self.onnx_models[position]
if model is None:
model = self.models[position]

View File

@ -112,8 +112,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
# print("posi", position)
if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
last_onnx_frame = onnx_frame.value
for p in positions:
model.set_onnx_model()
model.set_onnx_model(device)
while True:
if len(obs['legal_actions']) > 1: