253 lines
10 KiB
Python
253 lines
10 KiB
Python
|
import os
|
||
|
import threading
|
||
|
import time
|
||
|
import timeit
|
||
|
import pprint
|
||
|
from collections import deque
|
||
|
import numpy as np
|
||
|
|
||
|
import torch
|
||
|
from torch import multiprocessing as mp
|
||
|
from torch import nn
|
||
|
|
||
|
import douzero.dmc.models
|
||
|
import douzero.env.env
|
||
|
from .file_writer import FileWriter
|
||
|
from .models import Model, OldModel
|
||
|
from .utils import get_batch, log, create_env, create_optimizers, act
|
||
|
|
||
|
# Per-position rolling buffers of mean episode returns; each deque keeps at
# most the 100 most recently appended values.
mean_episode_return_buf = {
    'landlord': deque(maxlen=100),
    'landlord_up': deque(maxlen=100),
    'landlord_down': deque(maxlen=100),
    'bidding': deque(maxlen=100),
}
|
||
|
|
||
|
def compute_loss(logits, targets):
    """Return the mean squared error between ``logits`` and ``targets``.

    The last dimension of ``logits`` is squeezed (a no-op unless it has
    size 1) before ``targets`` is subtracted; the squared differences are
    then averaged over all elements.
    """
    residual = logits.squeeze(-1) - targets
    return residual.pow(2).mean()
|
||
|
|
||
|
def compute_loss_for_bid(outputs, reward):
    """Placeholder for a bidding-specific loss.

    Not implemented: both arguments are ignored and ``None`` is returned.
    """
    return None
|
||
|
|
||
|
def learn(position, actor_models, model, batch, optimizer, flags, lock):
    """Performs a learning (optimization) step.

    Args:
        position: One of "landlord", "landlord_up", "landlord_down" or
            "bidding"; selects the return buffer and the stats key suffix.
        actor_models: Mapping whose values expose ``get_model(position)``;
            every one of them receives the updated weights after the step.
        model: Learner network for ``position``. It is called as
            ``model(obs_z, obs_x, return_value=True)`` and must return a
            mapping with a ``'values'`` entry.
        batch: Mapping with the tensors "obs_x_batch", "obs_z", "target",
            "episode_return", "done" and "obs_type". The first two
            dimensions of the observations and targets are flattened
            together before the forward pass.
        optimizer: Optimizer that updates ``model``.
        flags: Run configuration; ``training_device`` and ``max_grad_norm``
            are read here.
        lock: Context manager held for the forward/backward pass, the
            optimizer step and the weight sync to the actor models.

    Returns:
        A dict with ``'mean_episode_return_<position>'`` and
        ``'loss_<position>'`` as Python floats.
    """
    position_index = {"landlord": 31, "landlord_up": 32, "landlord_down": 33}
    print("Learn", position)
    if flags.training_device != "cpu":
        device = torch.device('cuda:' + str(flags.training_device))
    else:
        device = torch.device('cpu')

    obs_x = torch.flatten(batch["obs_x_batch"], 0, 1).to(device)
    obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
    target = torch.flatten(batch['target'].to(device), 0, 1)

    # Select returns of finished episodes whose obs_type belongs to this
    # position; bidding batches match obs_type 41, 42 or 43.
    obs_type = batch["obs_type"]
    if position != "bidding":
        finished = batch['done'] & (obs_type == position_index[position])
    else:
        finished = batch['done'] & ((obs_type == 41) | (obs_type == 42) | (obs_type == 43))
    episode_returns = batch['episode_return'][finished]
    if len(episode_returns) > 0:
        mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device))

    with lock:
        learner_outputs = model(obs_z, obs_x, return_value=True)
        # The bidding branch used to leave `loss` unbound, which crashed on
        # `loss.item()` below. compute_loss_for_bid is still a stub, so the
        # bidding network is regressed onto `target` with the same MSE loss
        # as the playing positions.
        loss = compute_loss(learner_outputs['values'], target)

        # torch.stack raises on an empty list, which happens until the first
        # finished episode for this position has been seen; report 0.0 then.
        recent_returns = list(mean_episode_return_buf[position])
        if recent_returns:
            mean_return = torch.mean(torch.stack(recent_returns)).item()
        else:
            mean_return = 0.0

        stats = {
            'mean_episode_return_' + position: mean_return,
            'loss_' + position: loss.item(),
        }

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        # Push the freshly updated weights to every actor-side copy.
        for actor_model in actor_models.values():
            actor_model.get_model(position).load_state_dict(model.state_dict())
        return stats
|
||
|
|
||
|
def train(flags):
    """
    This is the main function for training. It will first
    initialize everything, such as buffers, optimizers, etc.
    Then it will start subprocesses as actors. Then, it will call
    learning function with multiple threads.

    Args:
        flags: Run configuration object. Attributes read here:
            actor_device_cpu, training_device, xpid, savedir,
            unroll_length, batch_size, num_actor_devices, gpu_devices,
            load_model, num_actors, num_threads, total_frames,
            disable_checkpoint and save_interval (minutes).

    Raises:
        AssertionError: if a CUDA device is requested but CUDA is not
            available, or if more actor devices are requested than are
            listed in ``flags.gpu_devices``.
    """
    if not flags.actor_device_cpu or flags.training_device != 'cpu':
        if not torch.cuda.is_available():
            raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`")
    plogger = FileWriter(
        xpid=flags.xpid,
        xp_args=flags.__dict__,
        rootdir=flags.savedir,
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))

    T = flags.unroll_length
    B = flags.batch_size

    positions = ['landlord', 'landlord_up', 'landlord_down', 'bidding']

    if flags.actor_device_cpu:
        device_iterator = ['cpu']
    else:
        device_iterator = range(flags.num_actor_devices)
        assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'

    # Initialize actor models (one shared-memory model per actor device key)
    models = {}
    for device in device_iterator:
        model = Model(device="cpu")
        model.share_memory()
        model.eval()
        models[device] = model

    # Initialize queues (one batch queue per position)
    actor_processes = []
    ctx = mp.get_context('spawn')
    batch_queues = {position: ctx.SimpleQueue() for position in positions}

    # Learner model for training
    learner_model = Model(device=flags.training_device)

    # Create optimizers
    optimizers = create_optimizers(flags, learner_model)

    # Stat Keys: mean return and loss for every position, in position order
    stat_keys = [
        key
        for position in positions
        for key in ('mean_episode_return_' + position, 'loss_' + position)
    ]
    frames, stats = 0, {k: 0 for k in stat_keys}
    position_frames = {position: 0 for position in positions}

    # Load models if any
    if flags.load_model and os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=("cuda:"+str(flags.training_device) if flags.training_device != "cpu" else "cpu")
        )
        model_states = checkpoint_states["model_state_dict"]
        optimizer_states = checkpoint_states["optimizer_state_dict"]
        for k in positions:
            # The stats/position_frames backfill below already tolerates
            # checkpoints that lack bidding entries; skip missing weights
            # as well instead of raising KeyError.
            if k not in model_states:
                log.info("Checkpoint has no '%s' model state; keeping freshly initialized weights.", k)
                continue
            learner_model.get_model(k).load_state_dict(model_states[k])
            if k in optimizer_states:
                optimizers[k].load_state_dict(optimizer_states[k])
            for device in device_iterator:
                models[device].get_model(k).load_state_dict(model_states[k])
        stats = checkpoint_states["stats"]

        # Backfill bidding entries that such checkpoints do not contain.
        stats.setdefault("mean_episode_return_bidding", 0)
        stats.setdefault("loss_bidding", 0)
        frames = checkpoint_states["frames"]
        position_frames = checkpoint_states["position_frames"]
        position_frames.setdefault("bidding", 0)
        log.info(f"Resuming preempted job, current stats:\n{stats}")

    # Starting actor processes
    # NOTE(review): the queues come from the 'spawn' context while the
    # processes are created through mp.Process (default start method) --
    # confirm this mix is intended.
    for device in device_iterator:
        for i in range(flags.num_actors):
            actor = mp.Process(
                target=act,
                args=(i, device, batch_queues, models[device], flags))
            actor.start()
            actor_processes.append(actor)

    # One lock shared by all learner threads; guards stats, the frame
    # counters and the logger.
    stats_lock = threading.Lock()

    def batch_and_learn(i, device, position, local_lock, position_lock):
        """Thread target for the learning process."""
        nonlocal frames, position_frames, stats
        while frames < flags.total_frames:
            batch = get_batch(batch_queues, position, flags, local_lock)
            _stats = learn(position, models, learner_model.get_model(position), batch,
                           optimizers[position], flags, position_lock)
            with stats_lock:
                for k in _stats:
                    stats[k] = _stats[k]
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B
                position_frames[position] += T * B

    threads = []
    # Per-device, per-position locks handed to get_batch.
    locks = {}
    for device in device_iterator:
        locks[device] = {position: threading.Lock() for position in positions}
    # Per-position locks handed to learn(); shared across devices.
    position_locks = {position: threading.Lock() for position in positions}

    for device in device_iterator:
        for i in range(flags.num_threads):
            for position in positions:
                thread = threading.Thread(
                    target=batch_and_learn, name='batch-and-learn-%d' % i,
                    args=(i, device, position, locks[device][position], position_locks[position]))
                thread.start()
                threads.append(thread)

    def checkpoint(frames):
        """Save the full training state and per-position weight files."""
        if flags.disable_checkpoint:
            return
        log.info('Saving checkpoint to %s', checkpointpath)
        _models = learner_model.get_models()
        torch.save({
            'model_state_dict': {k: _models[k].state_dict() for k in _models},
            'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
            "stats": stats,
            'flags': vars(flags),
            'frames': frames,
            'position_frames': position_frames
        }, checkpointpath)

        # Save the weights for evaluation purpose
        for position in positions:
            model_weights_dir = os.path.expandvars(os.path.expanduser(
                '%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt')))
            torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)

    # Keep only the 24 most recent fps samples for the running average.
    fps_log = deque(maxlen=24)
    timer = timeit.default_timer
    try:
        # Backdate the timestamp so the first loop iteration checkpoints.
        last_checkpoint_time = timer() - flags.save_interval * 60
        while frames < flags.total_frames:
            start_frames = frames
            position_start_frames = {k: position_frames[k] for k in position_frames}
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > flags.save_interval * 60:
                checkpoint(frames)
                last_checkpoint_time = timer()
            end_time = timer()

            elapsed = end_time - start_time
            fps = (frames - start_frames) / elapsed
            fps_log.append(fps)
            fps_avg = np.mean(list(fps_log))

            position_fps = {k: (position_frames[k] - position_start_frames[k]) / elapsed for k in position_frames}
            log.info('After %i (L:%i U:%i D:%i) frames: @ %.1f fps (avg@ %.1f fps) (L:%.1f U:%.1f D:%.1f) Stats:\n%s',
                     frames,
                     position_frames['landlord'],
                     position_frames['landlord_up'],
                     position_frames['landlord_down'],
                     fps,
                     fps_avg,
                     position_fps['landlord'],
                     position_fps['landlord_up'],
                     position_fps['landlord_down'],
                     pprint.pformat(stats))

    except KeyboardInterrupt:
        # Interrupted runs return without a final checkpoint.
        return
    else:
        for thread in threads:
            thread.join()
        log.info('Learning finished after %d frames.', frames)

    checkpoint(frames)
    plogger.close()
|