调整参数

This commit is contained in:
ZaneYork 2021-12-19 19:09:27 +08:00
parent a069ec2026
commit 5b0fee04a8
3 changed files with 4 additions and 4 deletions

View File

@ -1 +1 @@
python train.py --load_model --batch_size 8 --learning_rate 0.0003 --enable_onnx python train.py --load_model --batch_size 4 --learning_rate 0.0003 --enable_onnx

View File

@ -11,8 +11,8 @@ parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp'
help='Use ADP or WP as reward (default: ADP)') help='Use ADP or WP as reward (default: ADP)')
# Training settings # Training settings
parser.add_argument('--onnx_sync_interval', default=5, type=int, parser.add_argument('--onnx_sync_interval', default=30, type=int,
help='Time interval (in minutes) at which to sync the onnx model') help='Time interval (in seconds) at which to sync the onnx model')
parser.add_argument('--actor_device_cpu', action='store_true', parser.add_argument('--actor_device_cpu', action='store_true',
help='Use CPU as actor device') help='Use CPU as actor device')
parser.add_argument('--gpu_devices', default='0', type=str, parser.add_argument('--gpu_devices', default='0', type=str,

View File

@ -254,7 +254,7 @@ def train(flags):
checkpoint(frames) checkpoint(frames)
last_checkpoint_time = timer() last_checkpoint_time = timer()
if timer() - last_onnx_sync_time > flags.onnx_sync_interval * 60: if timer() - last_onnx_sync_time > flags.onnx_sync_interval:
sync_onnx_model(frames) sync_onnx_model(frames)
last_onnx_sync_time = timer() last_onnx_sync_time = timer()