309 lines
13 KiB
Python
309 lines
13 KiB
Python
import os
|
|
import threading
|
|
import time
|
|
import timeit
|
|
import pprint
|
|
from collections import deque
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch import multiprocessing as mp
|
|
from torch import nn
|
|
|
|
import douzero.dmc.models
|
|
import douzero.env.env
|
|
from .file_writer import FileWriter
|
|
from .models import Model, OldModel
|
|
from .utils import get_batch, log, create_env, create_optimizers, act
|
|
import psutil
|
|
import shutil
|
|
|
|
# Per-seat rolling window holding the 100 most recent per-batch mean episode
# returns; learn() appends to it and reports the mean of the window.
mean_episode_return_buf = {
    seat: deque(maxlen=100)
    for seat in ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')
}
|
|
|
|
def compute_loss(logits, targets):
    """Return the mean squared error between predictions and targets.

    The trailing dimension of ``logits`` is squeezed (a no-op unless it has
    size 1) before the element-wise difference with ``targets`` is taken.

    Args:
        logits: Predicted values tensor.
        targets: Target values tensor, broadcast-compatible with the
            squeezed ``logits``.

    Returns:
        A scalar tensor holding the mean of the squared differences.
    """
    predictions = logits.squeeze(-1)
    squared_error = (predictions - targets).pow(2)
    return squared_error.mean()
|
|
|
|
def learn(position, actor_models, model, batch, optimizer, flags, lock):
    """Perform one learning (optimization) step for a single seat's model.

    Args:
        position: Seat name; one of 'landlord', 'landlord_up',
            'landlord_front', 'landlord_down'.
        actor_models: Mapping of actor device -> model container. After the
            step, each container's ``get_model(position)`` receives the
            updated weights (skipped when ``flags.enable_onnx`` is set).
        model: The learner network for ``position``.
        batch: Dict of tensors. Keys read here are 'obs_x_batch', 'obs_z',
            'target', 'episode_return', 'done', 'obs_type' and, with
            ``flags.old_model``, 'obs_x_no_action'.
        optimizer: Optimizer for ``model``.
        flags: Parsed command-line options (training_device, old_model,
            max_grad_norm, enable_onnx are read).
        lock: Lock held while the return buffer is updated and while the
            forward/backward pass and the weight sync run. All threads
            learning the same seat must pass the same lock.

    Returns:
        Dict with 'loss_<position>' and, once at least one finished episode
        for this seat has been observed, 'mean_episode_return_<position>'.
    """
    # Value of batch['obs_type'] that tags samples belonging to each seat.
    position_index = {"landlord": 31, "landlord_up": 32, 'landlord_front': 33, "landlord_down": 34}
    print("Learn", position)
    if flags.training_device != "cpu":
        device = torch.device('cuda:' + str(flags.training_device))
    else:
        device = torch.device('cpu')

    if flags.old_model:
        # The old model takes state and action features concatenated on dim 2.
        obs_x_no_action = batch['obs_x_no_action'].to(device)
        obs_action = batch['obs_x_batch'].to(device)
        obs_x = torch.cat((obs_x_no_action, obs_action), dim=2).float()
        obs_x = torch.flatten(obs_x, 0, 1)
    else:
        obs_x = torch.flatten(batch["obs_x_batch"], 0, 1).to(device)
    # Merge the two leading dims (presumably unroll_length x batch_size --
    # TODO confirm against get_batch) into a single sample axis.
    obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
    target = torch.flatten(batch['target'].to(device), 0, 1)

    # Returns of steps that are terminal (done) and tagged with this seat.
    seat_done = batch['done'] & (batch["obs_type"] == position_index[position])
    episode_returns = batch['episode_return'][seat_done]
    return_buf = mean_episode_return_buf[position]

    with lock:
        # The deque is shared by every learner thread of this seat; appending
        # under the same lock used for reading it below prevents a
        # "deque mutated during iteration" error.
        if len(episode_returns) > 0:
            return_buf.append(torch.mean(episode_returns).to(device))

        learner_outputs = model(obs_z, obs_x)
        loss = compute_loss(learner_outputs['values'], target)

        stats = {}
        # The buffer stays empty until a batch contains a finished episode for
        # this seat, and torch.stack([]) raises; omit the key until then so
        # the caller keeps its previous value.
        if return_buf:
            stats['mean_episode_return_' + position] = torch.mean(torch.stack(list(return_buf))).item()
        stats['loss_' + position] = loss.item()

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        if not flags.enable_onnx:
            # Push the updated weights into every actor model for this seat.
            for actor_model in actor_models.values():
                actor_model.get_model(position).load_state_dict(model.state_dict())
        return stats
|
|
|
|
def train(flags):
    """Main function for training.

    Initializes everything (actor models, batch queues, learner model,
    optimizers), optionally restores state from ``<savedir>/<xpid>/model.tar``,
    starts the actor subprocesses, and then runs ``flags.num_threads``
    learner threads per seat. The main thread periodically saves checkpoints,
    exports ONNX models, logs throughput and restarts actor processes that
    have died.

    Args:
        flags: Parsed command-line options.

    Raises:
        AssertionError: If a CUDA device is requested but CUDA is not
            available, or if ``flags.num_actor_devices`` exceeds the number of
            ids listed in ``flags.gpu_devices``.
    """
    positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']

    if not flags.actor_device_cpu or flags.training_device != 'cpu':
        if not torch.cuda.is_available():
            raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`")
    plogger = FileWriter(
        xpid=flags.xpid,
        xp_args=flags.__dict__,
        rootdir=flags.savedir,
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))

    T = flags.unroll_length
    B = flags.batch_size

    if flags.actor_device_cpu:
        device_iterator = ['cpu']
    else:
        # NOTE(review): hard-coded to GPU 0 plus CPU actors;
        # flags.num_actor_devices is only validated here, not used.
        device_iterator = [0, 'cpu']
        # Explicit raise (not assert) so the check survives `python -O`.
        if flags.num_actor_devices > len(flags.gpu_devices.split(',')):
            raise AssertionError('The number of actor devices can not exceed the number of available devices')

    # Initialize actor models. Every actor model is built on the CPU and moved
    # to shared memory so the actor processes started below can read the
    # weights that learn() copies into them.
    models = {}
    for device in device_iterator:
        if flags.old_model:
            model = OldModel(device="cpu", flags=flags)
        else:
            model = Model(device="cpu", flags=flags)
        model.share_memory()
        model.eval()
        models[device] = model

    # Initialize queues
    actor_processes = []
    ctx = mp.get_context('spawn')
    batch_queues = {position: ctx.SimpleQueue() for position in positions}
    # Frame count of the most recent ONNX sync (-1 until the first sync);
    # shared with the actor processes.
    onnx_frame = ctx.Value('d', -1)

    # Learner model for training.
    # NOTE(review): unlike the actor models, the learner is constructed without
    # flags=flags -- confirm this is intended.
    if flags.old_model:
        learner_model = OldModel(device=flags.training_device)
    else:
        learner_model = Model(device=flags.training_device)

    # Create optimizers
    optimizers = create_optimizers(flags, learner_model)

    # Stat keys, in the order: mean return then loss, for each seat.
    stat_keys = [
        key
        for position in positions
        for key in ('mean_episode_return_' + position, 'loss_' + position)
    ]
    frames, stats = 0, {k: 0 for k in stat_keys}
    position_frames = {position: 0 for position in positions}
    # One lock per seat, held by learn() and by the ONNX export of that seat.
    position_locks = {position: threading.Lock() for position in positions}

    if flags.training_device == 'cpu':
        onnx_device = torch.device('cpu')
    else:
        onnx_device = torch.device('cuda:' + str(flags.training_device))

    def sync_onnx_model(frames):
        """Export every seat's learner model to ONNX (when enabled) and
        publish ``frames`` to the actors via ``onnx_frame``."""
        p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
        os.makedirs(p_path, exist_ok=True)
        if flags.enable_onnx:
            for position in positions:
                model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
                onnx_params = learner_model.get_model(position).get_onnx_params(onnx_device)
                with position_locks[position]:
                    torch.onnx.export(
                        learner_model.get_model(position),
                        onnx_params['args'],
                        model_path,
                        export_params=True,
                        opset_version=10,
                        do_constant_folding=True,
                        input_names=onnx_params['input_names'],
                        output_names=onnx_params['output_names'],
                        dynamic_axes=onnx_params['dynamic_axes']
                    )
        onnx_frame.value = frames

    # Load models if any
    if flags.load_model and os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=("cuda:" + str(flags.training_device) if flags.training_device != "cpu" else "cpu")
        )
        for k in positions:
            learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
            optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
            if not flags.enable_onnx:
                for device in device_iterator:
                    models[device].get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
        stats = checkpoint_states["stats"]
        frames = checkpoint_states["frames"]
        position_frames = checkpoint_states["position_frames"]
        sync_onnx_model(frames)
        log.info(f"Resuming preempted job, current stats:\n{stats}")

    # Starting actor processes
    for device in device_iterator:
        num_actors = flags.num_actors_cpu if device == 'cpu' else flags.num_actors
        for i in range(num_actors):
            actor = mp.Process(
                target=act,
                args=(i, device, batch_queues, models[device], flags, onnx_frame))
            actor.daemon = True
            actor.start()
            actor_processes.append({
                'device': device,
                'i': i,
                'actor': actor
            })

    # Lower the scheduling priority of this process and its children.
    # BELOW_NORMAL_PRIORITY_CLASS exists only on Windows; on POSIX psutil's
    # nice() takes a niceness value instead.
    low_priority = getattr(psutil, 'BELOW_NORMAL_PRIORITY_CLASS', 10)
    parent = psutil.Process()
    parent.nice(low_priority)
    for child in parent.children():
        try:
            child.nice(low_priority)
        except psutil.NoSuchProcess:
            # Best effort: a child that already exited needs no adjustment.
            pass

    # Guards the shared stats / frame counters and the plogger across all
    # learner threads.
    stats_lock = threading.Lock()

    def batch_and_learn(i, position, local_lock, position_lock, lock=stats_lock):
        """Thread target for the learning process.

        Repeatedly fetches a batch for ``position`` and runs one learn() step,
        then folds the returned stats and the consumed frames into the shared
        counters under ``lock``, until ``flags.total_frames`` is reached.
        """
        nonlocal frames
        while frames < flags.total_frames:
            batch = get_batch(batch_queues, position, flags, local_lock)
            _stats = learn(position, models, learner_model.get_model(position), batch,
                           optimizers[position], flags, position_lock)
            with lock:
                stats.update(_stats)
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B
                position_frames[position] += T * B

    # All actors share the same batch_queues, so one lock per seat serializes
    # get_batch across the learner threads of that seat.
    batch_locks = {position: threading.Lock() for position in positions}
    threads = []
    for i in range(flags.num_threads):
        for position in positions:
            thread = threading.Thread(
                target=batch_and_learn, name='batch-and-learn-%d' % i,
                args=(i, position, batch_locks[position], position_locks[position]))
            thread.daemon = True
            thread.start()
            threads.append(thread)

    def checkpoint(frames):
        """Save learner weights, optimizer states, stats and frame counters,
        plus per-seat weight files for evaluation."""
        if flags.disable_checkpoint:
            return
        log.info('Saving checkpoint to %s', checkpointpath)
        _models = learner_model.get_models()
        torch.save({
            'model_state_dict': {k: _models[k].state_dict() for k in _models},
            'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
            "stats": stats,
            'flags': vars(flags),
            'frames': frames,
            'position_frames': position_frames
        }, checkpointpath + '.new')

        # Save the weights for evaluation purpose
        for position in positions:
            model_weights_dir = os.path.expandvars(os.path.expanduser(
                '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
            torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
        # Write to a temporary file first, then move it into place, so an
        # interrupted save never leaves a truncated model.tar behind.
        shutil.move(checkpointpath + '.new', checkpointpath)

    # FPS samples of the most recent 240 reporting intervals.
    fps_log = deque(maxlen=240)
    timer = timeit.default_timer
    try:
        # Back-date so the first loop iteration saves a checkpoint immediately.
        last_checkpoint_time = timer() - flags.save_interval * 60
        last_onnx_sync_time = timer()
        while frames < flags.total_frames:
            start_frames = frames
            position_start_frames = dict(position_frames)
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > flags.save_interval * 60:
                checkpoint(frames)
                last_checkpoint_time = timer()

            if timer() - last_onnx_sync_time > flags.onnx_sync_interval:
                sync_onnx_model(frames)
                last_onnx_sync_time = timer()

            end_time = timer()
            elapsed = end_time - start_time

            fps = (frames - start_frames) / elapsed
            fps_log.append(fps)
            fps_avg = np.mean(fps_log)

            position_fps = {k: (position_frames[k] - position_start_frames[k]) / elapsed for k in position_frames}
            log.info('After %i (L:%i U:%i F:%i D:%i) frames: @ %.1f fps (avg@ %.1f fps) (L:%.1f U:%.1f F:%.1f D:%.1f) Stats:\n%s',
                     frames,
                     position_frames['landlord'],
                     position_frames['landlord_up'],
                     position_frames['landlord_front'],
                     position_frames['landlord_down'],
                     fps,
                     fps_avg,
                     position_fps['landlord'],
                     position_fps['landlord_up'],
                     position_fps['landlord_front'],
                     position_fps['landlord_down'],
                     pprint.pformat(stats))

            # Restart any actor process that has died, on its own device's model.
            for proc in actor_processes:
                if not proc['actor'].is_alive():
                    actor = mp.Process(
                        target=act,
                        args=(proc['i'], proc['device'], batch_queues, models[proc['device']], flags, onnx_frame))
                    actor.daemon = True
                    actor.start()
                    proc['actor'] = actor

    except KeyboardInterrupt:
        checkpoint(frames)
        return
    else:
        for thread in threads:
            thread.join()
        log.info('Learning finished after %d frames.', frames)

    checkpoint(frames)
    plogger.close()
|