Fix the bug that prevented CPU and GPU actors from training at the same time; automatically save a checkpoint on exit
parent 5b0fee04a8
commit 08cf434500
@@ -1 +1 @@
-python train.py --load_model --batch_size 4 --learning_rate 0.0003 --enable_onnx
+python train.py --load_model --batch_size 4 --learning_rate 0.0003 --enable_onnx --num_actors 1
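Note: the added --num_actors 1 keeps the example command light, since with this commit actors are spawned per device. A back-of-the-envelope check only, assuming the [0, 'cpu'] device list introduced further down:

num_actors = 1
device_iterator = [0, 'cpu']          # assumed: one GPU entry plus the CPU
print(len(device_iterator) * num_actors)  # 2 actor processes, one per device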
@@ -11,7 +11,7 @@ parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp'
                     help='Use ADP or WP as reward (default: ADP)')
 
 # Training settings
-parser.add_argument('--onnx_sync_interval', default=30, type=int,
+parser.add_argument('--onnx_sync_interval', default=120, type=int,
                     help='Time interval (in seconds) at which to sync the onnx model')
 parser.add_argument('--actor_device_cpu', action='store_true',
                     help='Use CPU as actor device')
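Note: a sketch of how --onnx_sync_interval is presumably consumed. The loop below is an assumption, not the repository's code; it only illustrates that raising the default from 30 to 120 seconds makes the learner re-export the ONNX models less often.

import time

last_sync = 0.0

def maybe_sync_onnx(flags, export_onnx_models):
    # export_onnx_models is a stand-in for whatever writes model_<position>.onnx
    global last_sync
    now = time.time()
    if now - last_sync >= flags.onnx_sync_interval:
        export_onnx_models()
        last_sync = now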
@@ -87,7 +87,7 @@ def train(flags):
     if flags.actor_device_cpu:
         device_iterator = ['cpu']
     else:
-        device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
+        device_iterator = [0, 'cpu'] #range(flags.num_actor_devices) #[0, 'cpu']
         assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
 
     # Initialize actor models
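Note: the [0, 'cpu'] list is what lets GPU and CPU actors run side by side. A minimal, assumed sketch of how such a device list can drive actor spawning; the real act() in this repository also receives batch_queues, model, flags and onnx_frame, omitted here for brevity.

import torch.multiprocessing as mp

def start_actors(flags, act):
    # An int entry selects a CUDA device, the string 'cpu' selects the CPU.
    device_iterator = ['cpu'] if flags.actor_device_cpu else [0, 'cpu']
    processes = []
    for device in device_iterator:
        for i in range(flags.num_actors):
            p = mp.Process(target=act, args=(i, device))
            p.start()
            processes.append(p)
    return processes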
@@ -282,6 +282,7 @@ def train(flags):
                      pprint.pformat(stats))
 
     except KeyboardInterrupt:
+        checkpoint(frames)
         return
     else:
         for thread in threads:
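Note: the added checkpoint(frames) implements the "save on exit" half of the commit message. A minimal sketch of the pattern, with a stand-in checkpoint function; the real one is defined elsewhere in train().

import torch

def checkpoint(frames, model, path='model.tar'):
    # Stand-in save routine: persist progress so Ctrl-C does not lose work.
    torch.save({'frames': frames, 'state_dict': model.state_dict()}, path)

def run(model, train_step):
    frames = 0
    try:
        while True:
            frames += train_step()
    except KeyboardInterrupt:
        checkpoint(frames, model)
        return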
@@ -315,18 +315,19 @@ class Model:
             'landlord_down': None
         }
 
-    def set_onnx_model(self):
+    def set_onnx_model(self, device='cpu'):
         positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
         for position in positions:
             model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.savedir, self.flags.xpid, position))
-            self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+            if device == 'cpu':
+                self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
+            else:
+                self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
 
     def get_onnx_params(self, position):
         self.models[position].get_onnx_params(self.device)
 
     def forward(self, position, z, x, return_value=False, flags=None, debug=False):
-        if self.flags.enable_onnx and len(self.onnx_models) == 0:
-            self.set_onnx_model()
         model = self.onnx_models[position]
         if model is None:
             model = self.models[position]
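Note: ONNX Runtime takes an ordered provider list per InferenceSession, so giving CPU actors only CPUExecutionProvider and GPU actors only CUDAExecutionProvider keeps each session on its own device instead of every session competing for CUDA. A self-contained sketch of the same idea (the model path is a placeholder, not a file from this repository):

import onnxruntime

def make_session(model_path, device='cpu'):
    # Fall back to CPU when this onnxruntime build has no CUDA provider.
    if device == 'cpu' or 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
        providers = ['CPUExecutionProvider']
    else:
        providers = ['CUDAExecutionProvider']
    return onnxruntime.InferenceSession(model_path, providers=providers)

# e.g. make_session('savedir/xpid/model_landlord.onnx', device=0)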
@@ -112,8 +112,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
                 # print("posi", position)
                 if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
                     last_onnx_frame = onnx_frame.value
-                    for p in positions:
-                        model.set_onnx_model()
+                    model.set_onnx_model(device)
 
                 while True:
                     if len(obs['legal_actions']) > 1:
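Note: each actor now reloads all four positions with a single model.set_onnx_model(device) call whenever the shared onnx_frame counter changes. A hedged sketch of that signalling pattern, using multiprocessing.Value as a stand-in for however onnx_frame is actually created in the repository:

import multiprocessing as mp

onnx_frame = mp.Value('d', -1)       # shared counter; -1 means no export yet

def publish_export(frames):
    onnx_frame.value = frames        # learner bumps it after writing new .onnx files

def maybe_reload(model, device, last_onnx_frame):
    # Actor side, mirroring the loop in act().
    if onnx_frame.value != last_onnx_frame:
        last_onnx_frame = onnx_frame.value
        model.set_onnx_model(device)
    return last_onnx_frame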