2021-09-07 16:38:34 +08:00
import os
import threading
import time
import timeit
import pprint
from collections import deque
import numpy as np
import torch
from torch import multiprocessing as mp
from torch import nn
import douzero . dmc . models
import douzero . env . env
from . file_writer import FileWriter
from . models import Model , OldModel
from . utils import get_batch , log , create_env , create_optimizers , act
2021-12-10 16:12:50 +08:00
import psutil
2021-12-11 20:07:41 +08:00
import shutil
2021-09-07 16:38:34 +08:00
2021-12-05 12:03:30 +08:00
# Rolling window (at most 100 entries) of recent mean episode returns, kept
# separately for every learner position; learn() appends to and averages it.
mean_episode_return_buf = {
    position: deque(maxlen=100)
    for position in ('landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding')
}

# Process-shared double, initialised to -1, that train() passes to each actor
# process.
onnx_frame = mp.Value('d', -1)
2021-09-07 16:38:34 +08:00
def compute_loss(logits, targets):
    """Return the mean squared error between predictions and targets.

    The last dimension of ``logits`` is squeezed away before ``targets`` is
    subtracted; the squared differences are then averaged over all elements.
    """
    residual = logits.squeeze(-1) - targets
    return residual.pow(2).mean()
def compute_loss_for_bid(outputs, reward):
    """Placeholder for a bidding-specific loss.

    Not implemented: both arguments are ignored and ``None`` is returned.
    """
    return None
def learn(position, actor_models, model, batch, optimizer, flags, lock):
    """Performs a learning (optimization) step.

    Args:
        position: Which model is being trained: one of the four playing
            positions ('landlord', 'landlord_up', 'landlord_front',
            'landlord_down') or 'bidding'.
        actor_models: Mapping whose values expose ``get_model(position)``;
            each of those models receives the learner weights after the step.
        model: The learner network, invoked as
            ``model(obs_z, obs_x, return_value=True)`` and expected to return
            a mapping with a ``'values'`` entry.
        batch: Mapping of tensors. Reads ``'obs_x_batch'``, ``'obs_z'``,
            ``'target'``, ``'episode_return'``, ``'done'`` and ``'obs_type'``,
            plus ``'obs_x_no_action'`` when ``flags.old_model`` is set and the
            position is not 'bidding'.
        optimizer: Optimizer that owns ``model``'s parameters.
        flags: Run configuration; reads ``training_device``, ``old_model``
            and ``max_grad_norm``.
        lock: Lock held for the whole update (forward pass, optimizer step
            and weight broadcast), and while the shared return buffer is
            touched.

    Returns:
        dict with ``'mean_episode_return_<position>'`` (average over the
        rolling buffer, ``0.0`` while the buffer is still empty) and
        ``'loss_<position>'``.
    """
    # obs_type codes identifying which playing position a transition belongs to.
    position_index = {"landlord": 31, "landlord_up": 32, 'landlord_front': 33, "landlord_down": 34}
    print("Learn", position)
    if flags.training_device != "cpu":
        device = torch.device('cuda:' + str(flags.training_device))
    else:
        device = torch.device('cpu')

    if flags.old_model and position != 'bidding':
        # Old-style models take the state features and the action features
        # concatenated along dim 2.
        obs_x_no_action = batch['obs_x_no_action'].to(device)
        obs_action = batch['obs_x_batch'].to(device)
        obs_x = torch.cat((obs_x_no_action, obs_action), dim=2).float()
        obs_x = torch.flatten(obs_x, 0, 1)
    else:
        obs_x = torch.flatten(batch["obs_x_batch"], 0, 1).to(device)
    obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
    target = torch.flatten(batch['target'].to(device), 0, 1)

    # Select the returns of finished episodes that belong to this learner.
    obs_type = batch["obs_type"]
    if position != "bidding":
        owned = obs_type == position_index[position]
    else:
        owned = (obs_type == 41) | (obs_type == 42) | (obs_type == 43) | (obs_type == 44)
    episode_returns = batch['episode_return'][batch['done'] & owned]
    batch_mean_return = None
    if len(episode_returns) > 0:
        batch_mean_return = torch.mean(episode_returns).to(device)

    return_buf = mean_episode_return_buf[position]
    with lock:
        # The buffer is module-level state shared by every thread training
        # this position, so it is only mutated and read while holding the lock.
        if batch_mean_return is not None:
            return_buf.append(batch_mean_return)

        learner_outputs = model(obs_z, obs_x, return_value=True)
        # Bidding and playing positions use the same squared-error loss.
        loss = compute_loss(learner_outputs['values'], target)

        # torch.stack() rejects an empty sequence; until the first episode for
        # this position finishes there is nothing to average, so report 0.0.
        if return_buf:
            mean_return = torch.mean(torch.stack(list(return_buf))).item()
        else:
            mean_return = 0.0
        stats = {
            'mean_episode_return_' + position: mean_return,
            'loss_' + position: loss.item(),
        }

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        # Push the freshly updated weights to every actor-side copy.
        for actor_model in actor_models.values():
            actor_model.get_model(position).load_state_dict(model.state_dict())
        return stats
2021-12-14 22:55:03 +08:00
def train(flags):
    """
    This is the main function for training. It will first
    initialize everything, such as buffers, optimizers, etc.
    Then it will start subprocesses as actors. Then, it will call
    the learning function with multiple threads.

    Args:
        flags: Run configuration namespace. Reads, among others,
            ``actor_device_cpu``, ``training_device``, ``num_actor_devices``,
            ``gpu_devices``, ``old_model``, ``load_model``, ``num_actors``,
            ``num_actors_cpu``, ``num_threads``, ``total_frames``,
            ``unroll_length``, ``batch_size``, ``save_interval``,
            ``disable_checkpoint``, ``savedir`` and ``xpid``.

    Raises:
        AssertionError: If a CUDA device is requested but CUDA is not
            available, or if more actor devices are requested than GPU ids
            were listed in ``flags.gpu_devices``.
    """
    if not flags.actor_device_cpu or flags.training_device != 'cpu':
        if not torch.cuda.is_available():
            raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`")
    plogger = FileWriter(
        xpid=flags.xpid,
        xp_args=flags.__dict__,
        rootdir=flags.savedir,
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))

    T = flags.unroll_length
    B = flags.batch_size
    positions = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding')

    if flags.actor_device_cpu:
        device_iterator = ['cpu']
    else:
        device_iterator = range(flags.num_actor_devices)  # [0, 'cpu']
        assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'

    # Initialize actor models (one shared-memory copy per actor device).
    model_cls = OldModel if flags.old_model else Model
    models = {}
    for device in device_iterator:
        model = model_cls(device="cpu")
        model.share_memory()
        model.eval()
        models[device] = model

    # Initialize queues
    actor_processes = []
    ctx = mp.get_context('spawn')
    batch_queues = {position: ctx.SimpleQueue() for position in positions}

    # Learner model for training
    learner_model = model_cls(device=flags.training_device)

    # Create optimizers
    optimizers = create_optimizers(flags, learner_model)

    # Stat Keys
    stat_keys = [
        'mean_episode_return_landlord',
        'loss_landlord',
        'mean_episode_return_landlord_up',
        'loss_landlord_up',
        'mean_episode_return_landlord_front',
        'loss_landlord_front',
        'mean_episode_return_landlord_down',
        'loss_landlord_down',
        'mean_episode_return_bidding',
        'loss_bidding',
    ]
    frames, stats = 0, {k: 0 for k in stat_keys}
    position_frames = {position: 0 for position in positions}

    # Load models if any
    if flags.load_model and os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath,
            map_location=("cuda:" + str(flags.training_device) if flags.training_device != "cpu" else "cpu")
        )
        for k in positions:
            learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
            optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
            for device in device_iterator:
                models[device].get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
        stats = checkpoint_states["stats"]
        # Fill in the bidding entries when the checkpoint does not carry them.
        stats.setdefault("mean_episode_return_bidding", 0)
        stats.setdefault("loss_bidding", 0)
        frames = checkpoint_states["frames"]
        position_frames = checkpoint_states["position_frames"]
        position_frames.setdefault("bidding", 0)
        log.info(f"Resuming preempted job, current stats:\n{stats}")

    # Starting actor processes
    for device in device_iterator:
        if device == 'cpu':
            num_actors = flags.num_actors_cpu
        else:
            num_actors = flags.num_actors
        for i in range(num_actors):
            actor = mp.Process(
                target=act,
                args=(i, device, batch_queues, models[device], flags, onnx_frame))
            actor.daemon = True
            actor.start()
            actor_processes.append(actor)

    # Lower the scheduling priority of this process and of the actors.
    # BELOW_NORMAL_PRIORITY_CLASS only exists in psutil on Windows; elsewhere
    # Process.nice() takes a niceness value, so fall back to 10 instead of
    # failing with AttributeError after the actors are already running.
    low_priority = getattr(psutil, 'BELOW_NORMAL_PRIORITY_CLASS', 10)
    parent = psutil.Process()
    parent.nice(low_priority)
    for child in parent.children():
        child.nice(low_priority)

    # NOTE: the default ``lock`` is evaluated once, when this def runs, so
    # every learner thread shares that single lock for the bookkeeping below.
    def batch_and_learn(i, device, position, local_lock, position_lock, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal frames, position_frames, stats
        while frames < flags.total_frames:
            batch = get_batch(batch_queues, position, flags, local_lock)
            _stats = learn(position, models, learner_model.get_model(position), batch,
                           optimizers[position], flags, position_lock)
            with lock:
                for k in _stats:
                    stats[k] = _stats[k]
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B
                position_frames[position] += T * B

    threads = []
    # One batching lock per (device, position) and one update lock per position.
    locks = {
        device: {position: threading.Lock() for position in positions}
        for device in device_iterator
    }
    position_locks = {position: threading.Lock() for position in positions}

    for device in device_iterator:
        for i in range(flags.num_threads):
            for position in positions:
                thread = threading.Thread(
                    target=batch_and_learn,
                    name='batch-and-learn-%d' % i,
                    args=(i, device, position, locks[device][position], position_locks[position]))
                # Thread.setDaemon() is deprecated; set the attribute instead.
                thread.daemon = True
                thread.start()
                threads.append(thread)

    def checkpoint(frames):
        """Save the training state, per-position weights and ONNX exports."""
        if flags.disable_checkpoint:
            return
        log.info('Saving checkpoint to %s', checkpointpath)
        _models = learner_model.get_models()
        # Write to a side file first; it is moved over the real checkpoint
        # only after everything else below has succeeded.
        torch.save({
            'model_state_dict': {k: _models[k].state_dict() for k in _models},
            'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
            "stats": stats,
            'flags': vars(flags),
            'frames': frames,
            'position_frames': position_frames
        }, checkpointpath + '.new')

        # Save the weights for evaluation purpose
        dummy_input = (
            torch.zeros((1, 40, 108), dtype=torch.float32),
            torch.zeros((1, 80), dtype=torch.float32),
        )
        for position in positions:
            model_weights_dir = os.path.expandvars(os.path.expanduser(
                '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
            torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
            if position != 'bidding':
                model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
                torch.onnx.export(
                    learner_model.get_model(position),
                    dummy_input,
                    model_path,
                    input_names=['z_batch', 'x_batch'],
                    output_names=['values'],
                    dynamic_axes={
                        'z_batch': {
                            0: "legal_actions"
                        },
                        'x_batch': {
                            0: "legal_actions"
                        }
                    }
                )
        # Publish the frame count through the shared Value. Assigning to the
        # bare name would only rebind it in this process and the actors, which
        # hold the original Value object, would never see the update.
        onnx_frame.value = frames
        shutil.move(checkpointpath + '.new', checkpointpath)

    # Sliding window over the 24 most recent throughput samples.
    fps_log = deque(maxlen=24)
    timer = timeit.default_timer
    try:
        # Back-date the timestamp so the first loop iteration checkpoints.
        last_checkpoint_time = timer() - flags.save_interval * 60
        while frames < flags.total_frames:
            start_frames = frames
            position_start_frames = dict(position_frames)
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > flags.save_interval * 60:
                checkpoint(frames)
                last_checkpoint_time = timer()
            end_time = timer()
            elapsed = end_time - start_time

            fps = (frames - start_frames) / elapsed
            fps_log.append(fps)
            fps_avg = np.mean(fps_log)

            position_fps = {
                k: (position_frames[k] - position_start_frames[k]) / elapsed
                for k in position_frames
            }
            log.info('After %i (L:%i U:%i F:%i D:%i) frames: @ %.1f fps (avg@ %.1f fps) (L:%.1f U:%.1f F:%.1f D:%.1f) Stats:\n%s',
                     frames,
                     position_frames['landlord'],
                     position_frames['landlord_up'],
                     position_frames['landlord_front'],
                     position_frames['landlord_down'],
                     fps,
                     fps_avg,
                     position_fps['landlord'],
                     position_fps['landlord_up'],
                     position_fps['landlord_front'],
                     position_fps['landlord_down'],
                     pprint.pformat(stats))
    except KeyboardInterrupt:
        # Interrupted by the user: leave without a final checkpoint.
        return
    else:
        for thread in threads:
            thread.join()
        log.info('Learning finished after %d frames.', frames)

    checkpoint(frames)
    plogger.close()