修复CPU与GPU不能同时训练的BUG,退出时自动存档
This commit is contained in:
parent
5b0fee04a8
commit
08cf434500
|
@ -1 +1 @@
|
|||
python train.py --load_model --batch_size 4 --learning_rate 0.0003 --enable_onnx
|
||||
python train.py --load_model --batch_size 4 --learning_rate 0.0003 --enable_onnx --num_actors 1
|
|
@ -11,7 +11,7 @@ parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp'
|
|||
help='Use ADP or WP as reward (default: ADP)')
|
||||
|
||||
# Training settings
|
||||
parser.add_argument('--onnx_sync_interval', default=30, type=int,
|
||||
parser.add_argument('--onnx_sync_interval', default=120, type=int,
|
||||
help='Time interval (in seconds) at which to sync the onnx model')
|
||||
parser.add_argument('--actor_device_cpu', action='store_true',
|
||||
help='Use CPU as actor device')
|
||||
|
|
|
@ -87,7 +87,7 @@ def train(flags):
|
|||
if flags.actor_device_cpu:
|
||||
device_iterator = ['cpu']
|
||||
else:
|
||||
device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
|
||||
device_iterator = [0, 'cpu'] #range(flags.num_actor_devices) #[0, 'cpu']
|
||||
assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
|
||||
|
||||
# Initialize actor models
|
||||
|
@ -282,6 +282,7 @@ def train(flags):
|
|||
pprint.pformat(stats))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
checkpoint(frames)
|
||||
return
|
||||
else:
|
||||
for thread in threads:
|
||||
|
|
|
@ -315,18 +315,19 @@ class Model:
|
|||
'landlord_down': None
|
||||
}
|
||||
|
||||
def set_onnx_model(self):
|
||||
def set_onnx_model(self, device='cpu'):
|
||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||
for position in positions:
|
||||
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.savedir, self.flags.xpid, position))
|
||||
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
if device == 'cpu':
|
||||
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
|
||||
else:
|
||||
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
|
||||
|
||||
def get_onnx_params(self, position):
|
||||
self.models[position].get_onnx_params(self.device)
|
||||
|
||||
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
|
||||
if self.flags.enable_onnx and len(self.onnx_models) == 0:
|
||||
self.set_onnx_model()
|
||||
model = self.onnx_models[position]
|
||||
if model is None:
|
||||
model = self.models[position]
|
||||
|
|
|
@ -112,8 +112,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
|
|||
# print("posi", position)
|
||||
if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
|
||||
last_onnx_frame = onnx_frame.value
|
||||
for p in positions:
|
||||
model.set_onnx_model()
|
||||
model.set_onnx_model(device)
|
||||
|
||||
while True:
|
||||
if len(obs['legal_actions']) > 1:
|
||||
|
|
Loading…
Reference in New Issue