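# DouZero_For_HLDDZ_FullAuto/douzero/dmc/dmc.py
#
# NOTE: the lines originally found here were repository-viewer page chrome
# (file size, "Raw / Permalink / Normal View / History" links, and the
# timestamp 2021-07-28 19:47:43 +08:00), not part of the module source.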
import os
import threading
import time
import timeit
import pprint
from collections import deque
import torch
from torch import multiprocessing as mp
from torch import nn
from .file_writer import FileWriter
from .models import Model
from .utils import get_batch, log, create_env, create_buffers, create_optimizers, act
# Per-position rolling window (at most 100 entries) of mean episode returns;
# `learn` appends to it and averages it for the reported statistics.
mean_episode_return_buf = {
    position: deque(maxlen=100)
    for position in ('landlord', 'landlord_up', 'landlord_down')
}
def compute_loss(logits, targets):
    """Return the mean squared error between predictions and targets.

    The last dimension of ``logits`` is squeezed (a no-op unless it has
    size 1) before ``targets`` is subtracted; the squared differences are
    then averaged over all elements into a scalar tensor.
    """
    residual = logits.squeeze(-1) - targets
    return residual.pow(2).mean()
def learn(position,
          actor_models,
          model,
          batch,
          optimizer,
          flags,
          lock):
    """Performs a learning (optimization) step.

    Args:
        position: Key of the role being trained. It indexes
            ``mean_episode_return_buf`` and is appended to the names of the
            returned statistics.
        actor_models: Iterable of actor-side model containers. Each must
            provide ``get_model(position)``; the returned module receives a
            copy of the learner weights after the optimizer step.
        model: Learner network for ``position``. It is called as
            ``model(obs_z, obs_x, return_value=True)`` and must return a
            mapping with a ``'values'`` entry.
        batch: Mapping holding tensors under ``'obs_x_no_action'``,
            ``'obs_action'``, ``'obs_z'``, ``'target'``, ``'episode_return'``
            and ``'done'``. The observation tensors are concatenated along
            dim 2 and their first two dims are flattened into one.
        optimizer: Optimizer that updates ``model``'s parameters.
        flags: Object providing ``training_device`` (``'cpu'`` or a CUDA
            device index) and ``max_grad_norm``.
        lock: Context manager guarding the update; it is held for the rolling
            return buffer update, the forward/backward pass, the optimizer
            step and the weight copy to the actors.

    Returns:
        dict: ``'mean_episode_return_<position>'`` (mean of the rolling
        buffer, NaN while the buffer is still empty) and
        ``'loss_<position>'`` as Python floats.
    """
    # 'cpu' selects the CPU; every other value is treated as a CUDA index.
    # NOTE(review): whether the project's Model itself supports CPU is not
    # visible from this function - confirm before training on CPU.
    training_device = str(flags.training_device)
    if training_device == 'cpu':
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:' + training_device)

    obs_x_no_action = batch['obs_x_no_action'].to(device)
    obs_action = batch['obs_action'].to(device)
    obs_x = torch.cat((obs_x_no_action, obs_action), dim=2).float()
    obs_x = torch.flatten(obs_x, 0, 1)
    obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
    target = torch.flatten(batch['target'].to(device), 0, 1)
    episode_returns = batch['episode_return'][batch['done']]

    with lock:
        # The buffer is shared by every learner thread of this position, so
        # it is only touched while holding the lock: appending from one
        # thread while another iterates the deque raises RuntimeError.
        return_buf = mean_episode_return_buf[position]
        # The mean of an empty selection is NaN and would poison the rolling
        # average until it is evicted, so batches in which no episode
        # finished contribute nothing.
        if episode_returns.numel() > 0:
            return_buf.append(torch.mean(episode_returns).to(device))

        learner_outputs = model(obs_z, obs_x, return_value=True)
        loss = compute_loss(learner_outputs['values'], target)

        if return_buf:
            mean_return = torch.mean(torch.stack(list(return_buf))).item()
        else:
            mean_return = float('nan')
        stats = {
            'mean_episode_return_'+position: mean_return,
            'loss_'+position: loss.item(),
        }

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        # Push the freshly updated weights to every actor copy.
        for actor_model in actor_models:
            actor_model.get_model(position).load_state_dict(model.state_dict())
    return stats
def train(flags):
    """
    This is the main function for training. It will first
    initialize everything, such as buffers, optimizers, etc.
    Then it will start subprocesses as actors. Then, it will call
    learning function with multiple threads.

    Args:
        flags: Configuration object. This function reads ``xpid``,
            ``savedir``, ``unroll_length``, ``batch_size``,
            ``num_actor_devices``, ``gpu_devices``, ``training_device``,
            ``load_model``, ``num_actors``, ``num_buffers``, ``num_threads``,
            ``total_frames``, ``save_interval`` (minutes) and
            ``disable_checkpoint``, and forwards ``flags`` to the helpers
            it calls.

    Returns:
        None. On ``KeyboardInterrupt`` the function returns immediately,
        without writing a final checkpoint or closing the file writer.
    """
    positions = ('landlord', 'landlord_up', 'landlord_down')

    plogger = FileWriter(
        xpid=flags.xpid,
        xp_args=flags.__dict__,
        rootdir=flags.savedir,
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))

    T = flags.unroll_length
    B = flags.batch_size

    # Initialize actor models
    models = []
    assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
    for device in range(flags.num_actor_devices):
        model = Model(device=device)
        model.share_memory()
        model.eval()
        models.append(model)

    # Initialize buffers
    buffers = create_buffers(flags)

    # Initialize queues: one free/full queue pair per position per device
    actor_processes = []
    ctx = mp.get_context('spawn')
    free_queue = []
    full_queue = []
    for device in range(flags.num_actor_devices):
        free_queue.append({p: ctx.SimpleQueue() for p in positions})
        full_queue.append({p: ctx.SimpleQueue() for p in positions})

    # Learner model for training
    learner_model = Model(device=flags.training_device)

    # Create optimizers
    optimizers = create_optimizers(flags, learner_model)

    # Stat Keys
    stat_keys = [
        'mean_episode_return_landlord',
        'loss_landlord',
        'mean_episode_return_landlord_up',
        'loss_landlord_up',
        'mean_episode_return_landlord_down',
        'loss_landlord_down',
    ]
    frames, stats = 0, {k: 0 for k in stat_keys}
    position_frames = {'landlord':0, 'landlord_up':0, 'landlord_down':0}

    # Load models if any
    if flags.load_model and os.path.exists(checkpointpath):
        # 'cpu' maps the checkpoint onto the CPU; every other value is
        # treated as a CUDA device index.
        training_device = str(flags.training_device)
        if training_device == 'cpu':
            map_location = 'cpu'
        else:
            map_location = "cuda:" + training_device
        checkpoint_states = torch.load(checkpointpath, map_location=map_location)
        for k in positions:
            learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
            optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
            # Keep every actor copy in sync with the restored learner.
            for device in range(flags.num_actor_devices):
                models[device].get_model(k).load_state_dict(learner_model.get_model(k).state_dict())
        stats = checkpoint_states["stats"]
        frames = checkpoint_states["frames"]
        position_frames = checkpoint_states["position_frames"]
        log.info(f"Resuming preempted job, current stats:\n{stats}")

    # Starting actor processes
    for device in range(flags.num_actor_devices):
        for i in range(flags.num_actors):
            actor = ctx.Process(
                target=act,
                args=(i, device, free_queue[device], full_queue[device], models[device], buffers[device], flags))
            actor.start()
            actor_processes.append(actor)

    # One lock shared by all learner threads; it guards the shared counters,
    # the stats dict and the file writer.
    stats_lock = threading.Lock()

    def batch_and_learn(i, device, position, local_lock, position_lock):
        """Thread target for the learning process."""
        nonlocal frames, position_frames, stats
        while frames < flags.total_frames:
            batch = get_batch(free_queue[device][position], full_queue[device][position], buffers[device][position], flags, local_lock)
            _stats = learn(position, models, learner_model.get_model(position), batch,
                           optimizers[position], flags, position_lock)

            with stats_lock:
                for k in _stats:
                    stats[k] = _stats[k]
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B
                position_frames[position] += T * B

    # Mark every buffer slot as free for each position on each device.
    for device in range(flags.num_actor_devices):
        for m in range(flags.num_buffers):
            for position in positions:
                free_queue[device][position].put(m)

    threads = []
    # `locks` are per device and position (passed to get_batch);
    # `position_locks` are per position only (passed to learn).
    locks = [{p: threading.Lock() for p in positions} for _ in range(flags.num_actor_devices)]
    position_locks = {p: threading.Lock() for p in positions}

    for device in range(flags.num_actor_devices):
        for i in range(flags.num_threads):
            for position in positions:
                thread = threading.Thread(
                    target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,device,position,locks[device][position],position_locks[position]))
                thread.start()
                threads.append(thread)

    def checkpoint(frames):
        """Save the full training state plus per-position weight files."""
        if flags.disable_checkpoint:
            return
        log.info('Saving checkpoint to %s', checkpointpath)
        _models = learner_model.get_models()
        torch.save({
            'model_state_dict': {k: _models[k].state_dict() for k in _models},
            'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
            "stats": stats,
            'flags': vars(flags),
            'frames': frames,
            'position_frames': position_frames
        }, checkpointpath)

        # Save the weights for evaluation purpose
        for position in positions:
            model_weights_dir = os.path.expandvars(os.path.expanduser(
                '%s/%s/%s' % (flags.savedir, flags.xpid, position+'_weights_'+str(frames)+'.ckpt')))
            torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)

    timer = timeit.default_timer
    try:
        # Back-date the last checkpoint by one full interval so the first
        # pass through the loop already writes a checkpoint.
        last_checkpoint_time = timer() - flags.save_interval * 60
        while frames < flags.total_frames:
            start_frames = frames
            position_start_frames = {k: position_frames[k] for k in position_frames}
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > flags.save_interval * 60:
                checkpoint(frames)
                last_checkpoint_time = timer()

            end_time = timer()
            fps = (frames - start_frames) / (end_time - start_time)
            position_fps = {k:(position_frames[k]-position_start_frames[k])/(end_time-start_time) for k in position_frames}
            log.info('After %i (L:%i U:%i D:%i) frames: @ %.1f fps (L:%.1f U:%.1f D:%.1f) Stats:\n%s',
                     frames,
                     position_frames['landlord'],
                     position_frames['landlord_up'],
                     position_frames['landlord_down'],
                     fps,
                     position_fps['landlord'],
                     position_fps['landlord_up'],
                     position_fps['landlord_down'],
                     pprint.pformat(stats))

    except KeyboardInterrupt:
        return
    else:
        for thread in threads:
            thread.join()
        log.info('Learning finished after %d frames.', frames)

    checkpoint(frames)
    plogger.close()