import os
import threading
import time
import timeit
import pprint
from collections import deque
import numpy as np
import torch
from torch import multiprocessing as mp
from torch import nn
import douzero.dmc.models
import douzero.env.env
from .file_writer import FileWriter
from .models import Model, OldModel, UnifiedModel
from .utils import get_batch, log, create_optimizers, act, infer_logic
import psutil
import shutil
import requests

mean_episode_return_buf = {p: deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']}

def compute_loss(logits, targets):
    loss = ((logits.squeeze(-1) - targets) ** 2).mean()
    return loss
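
# Note: as in DouZero's Deep Monte-Carlo (DMC) setup, the value head is regressed directly
# onto the sampled episode returns, so compute_loss above is a plain mean-squared-error objective.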


def learn(position, actor_model, model, batch, optimizer, flags, lock):
    """Performs a learning (optimization) step."""
    position_index = {"landlord": 31, "landlord_up": 32, 'landlord_front': 33, "landlord_down": 34}
    print("Learn", position)
    if flags.training_device != "cpu":
        device = torch.device('cuda:' + str(flags.training_device))
    else:
        device = torch.device('cpu')
    if flags.old_model:
        obs_x_no_action = batch['obs_x_no_action'].to(device)
        obs_action = batch['obs_x_batch'].to(device)
        obs_x = torch.cat((obs_x_no_action, obs_action), dim=2).float()
        obs_x = torch.flatten(obs_x, 0, 1)
    else:
        obs_x = batch["obs_x_batch"]
        obs_x = torch.flatten(obs_x, 0, 1).to(device)
    obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
    target = torch.flatten(batch['target'].to(device), 0, 1)
    episode_returns = batch['episode_return'][batch['done'] & (batch["obs_type"] == position_index[position])]
    if len(episode_returns) > 0:
        mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device))
    with lock:
        learner_outputs = model(obs_z, obs_x)
        loss = compute_loss(learner_outputs['values'], target)
        stats = {
            'mean_episode_return_' + position: torch.mean(torch.stack([_r for _r in mean_episode_return_buf[position]])).item(),
            'loss_' + position: loss.item(),
        }
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        if not flags.enable_onnx:
            actor_model.get_model(position).load_state_dict(model.state_dict())
        return stats


def train(flags):
    """
    This is the main function for training. It first initializes
    everything, such as buffers and optimizers. It then starts
    subprocesses as actors, and finally calls the learning function
    with multiple threads.
    """
    if flags.training_device != 'cpu' or flags.infer_devices != 'cpu':
        if not torch.cuda.is_available():
            raise AssertionError(
                "CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. "
                "Otherwise, please train on the CPU with `python3 train.py --infer_devices cpu --training_device cpu`")
    plogger = FileWriter(
        xpid=flags.xpid,
        xp_args=flags.__dict__,
        rootdir=flags.savedir,
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))

    T = flags.unroll_length
    B = flags.batch_size

    # Initialize actor models
    if flags.old_model:
        actor_model = OldModel(device="cpu", flags=flags, lite_model=flags.lite_model)
    elif flags.unified_model:
        actor_model = UnifiedModel(device="cpu", flags=flags, lite_model=flags.lite_model)
    else:
        actor_model = Model(device="cpu", flags=flags, lite_model=flags.lite_model)
    actor_model.eval()

    # Initialize queues
    actor_processes = []
    ctx = mp.get_context('spawn')
    batch_queues = {"landlord": ctx.SimpleQueue(), "landlord_up": ctx.SimpleQueue(), 'landlord_front': ctx.SimpleQueue(), "landlord_down": ctx.SimpleQueue()}
    # Frame count of the most recent ONNX export; -1 means no model has been exported yet
    onnx_frame = ctx.Value('d', -1)

    # Learner model for training
    if flags.old_model:
        learner_model = OldModel(device=flags.training_device, lite_model=flags.lite_model)
    elif flags.unified_model:
        learner_model = UnifiedModel(device=flags.training_device, lite_model=flags.lite_model)
    else:
        learner_model = Model(device=flags.training_device, lite_model=flags.lite_model)

    # Stat Keys
    stat_keys = [
        'mean_episode_return_landlord',
        'loss_landlord',
        'mean_episode_return_landlord_up',
        'loss_landlord_up',
        'mean_episode_return_landlord_front',
        'loss_landlord_front',
        'mean_episode_return_landlord_down',
        'loss_landlord_down'
    ]
    frames, stats = 0, {k: 0 for k in stat_keys}
    position_frames = {'landlord': 0, 'landlord_up': 0, 'landlord_front': 0, 'landlord_down': 0}
    position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}

    def sync_onnx_model(frames):
        p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
        if not os.path.exists(p_path):
            os.makedirs(p_path)
        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
            if flags.enable_onnx:
                model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
                onnx_params = learner_model.get_model(position) \
                    .get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
                with position_locks[position]:
                    torch.onnx.export(
                        learner_model.get_model(position),
                        onnx_params['args'],
                        model_path,
                        export_params=True,
                        opset_version=10,
                        do_constant_folding=True,
                        input_names=onnx_params['input_names'],
                        output_names=onnx_params['output_names'],
                        dynamic_axes=onnx_params['dynamic_axes']
                    )
        onnx_frame.value = frames
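
    # sync_onnx_model exports every position's network to ONNX and then bumps onnx_frame,
    # presumably so the inference processes (which are handed onnx_frame) can notice the
    # new export and reload it.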

    # Create optimizers
    optimizers = create_optimizers(flags, learner_model)

    # Load models if any
    if flags.load_model and os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=("cuda:" + str(flags.training_device) if flags.training_device != "cpu" else "cpu")
        )
        for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
            learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
            optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
            if not flags.enable_onnx:
                actor_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
        stats = checkpoint_states["stats"]
        frames = checkpoint_states["frames"]
        position_frames = checkpoint_states["position_frames"]
        sync_onnx_model(frames)
        log.info(f"Resuming preempted job, current stats:\n{stats}")

    # Inference request/response queues shared between the actor and inference processes
    infer_queues = []
    num_actors = flags.num_actors
    for j in range(flags.num_actors_thread):
        for i in range(num_actors):
            infer_queues.append({
                'input': ctx.Queue(maxsize=100), 'output': ctx.Queue(maxsize=100)
            })

    # Start the inference processes: flags.num_infer per GPU device, a single one for CPU
    infer_processes = []
    for device in flags.infer_devices.split(','):
        for i in range(flags.num_infer if device != 'cpu' else 1):
            infer = mp.Process(
                target=infer_logic,
                args=(i, device, infer_queues, actor_model, flags, onnx_frame))
            infer.daemon = True
            infer.start()
            infer_processes.append({
                'device': device,
                'i': i,
                'infer': infer
            })

    # Starting actor processes; each actor receives its own slice of inference queues
    # (the hard-coded 4 below appears to assume flags.num_actors_thread == 4)
    for i in range(num_actors):
        actor = mp.Process(
            target=act,
            args=(i, infer_queues[i * 4:(i + 1) * 4], batch_queues, flags))
        actor.daemon = True
        actor.start()
        actor_processes.append({
            'i': i,
            'actor': actor
        })

    parent = psutil.Process()
    parent.nice(psutil.NORMAL_PRIORITY_CLASS)
    for child in parent.children():
        child.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)
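
    # Note: NORMAL_PRIORITY_CLASS and BELOW_NORMAL_PRIORITY_CLASS are Windows-only psutil
    # constants; on Linux/macOS this priority adjustment would need plain Unix nice values instead.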

    def batch_and_learn(i, position, local_lock, position_lock, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal frames, position_frames, stats
        while frames < flags.total_frames:
            batch = get_batch(batch_queues, position, flags, local_lock)
            _stats = learn(position, actor_model, learner_model.get_model(position), batch,
                           optimizers[position], flags, position_lock)
            with lock:
                for k in _stats:
                    stats[k] = _stats[k]
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B
                position_frames[position] += T * B
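
    # Each learner step consumes a batch of T * B = unroll_length * batch_size frames, which
    # is the unit in which the frame counters (and the fps figures logged below) advance.
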
    threads = []
    locks = {}
    locks['cpu'] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}

    for i in range(flags.num_threads):
        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
            thread = threading.Thread(
                target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i, position, locks['cpu'][position], position_locks[position]))
            thread.daemon = True
            thread.start()
            threads.append(thread)

    def checkpoint(frames):
        if flags.disable_checkpoint:
            return
        log.info('Saving checkpoint to %s', checkpointpath)
        _models = learner_model.get_models()
        # Write the full training state to a temporary file first
        torch.save({
            'model_state_dict': {k: _models[k].state_dict() for k in _models},
            'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
            "stats": stats,
            'flags': vars(flags),
            'frames': frames,
            'position_frames': position_frames
        }, checkpointpath + '.new')

        # Save the weights for evaluation purposes
        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
            model_weights_dir = os.path.expandvars(os.path.expanduser(
                '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
            torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
            if flags.enable_upload:
                if flags.lite_model:
                    type = 'lite_'
                else:
                    type = ''
                if flags.old_model:
                    type += 'vanilla'
                elif flags.unified_model:
                    type += 'unified'
                else:
                    type += 'resnet'
                requests.post(flags.upload_url, data={
                    'type': type,
                    'position': position,
                    'frame': frames
                }, files={'model_file': ('model.ckpt', open(model_weights_dir, 'rb'))})
                os.remove(model_weights_dir)
        # Replace the previous checkpoint with the newly written one
        shutil.move(checkpointpath + '.new', checkpointpath)

    fps_log = []
    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer() - flags.save_interval * 60
        last_onnx_sync_time = timer()
        while frames < flags.total_frames:
            start_frames = frames
            position_start_frames = {k: position_frames[k] for k in position_frames}
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > flags.save_interval * 60:
                checkpoint(frames)
                last_checkpoint_time = timer()

            if timer() - last_onnx_sync_time > flags.onnx_sync_interval:
                sync_onnx_model(frames)
                last_onnx_sync_time = timer()

            end_time = timer()
            fps = (frames - start_frames) / (end_time - start_time)
            fps_log.append(fps)
            if len(fps_log) > 240:
                fps_log = fps_log[1:]
            fps_avg = np.mean(fps_log)
            position_fps = {k: (position_frames[k] - position_start_frames[k]) / (end_time - start_time) for k in position_frames}
            log.info('After %i (L:%i U:%i F:%i D:%i) frames: @ %.1f fps (avg@ %.1f fps) (L:%.1f U:%.1f F:%.1f D:%.1f) Stats:\n%s',
                     frames,
                     position_frames['landlord'],
                     position_frames['landlord_up'],
                     position_frames['landlord_front'],
                     position_frames['landlord_down'],
                     fps,
                     fps_avg,
                     position_fps['landlord'],
                     position_fps['landlord_up'],
                     position_fps['landlord_front'],
                     position_fps['landlord_down'],
                     pprint.pformat(stats))

            # Restart any actor or inference process that has died
            for proc in actor_processes:
                if not proc['actor'].is_alive():
                    i = proc['i']
                    actor = mp.Process(
                        target=act,
                        args=(i, infer_queues[i * 4:(i + 1) * 4], batch_queues, flags))
                    actor.daemon = True
                    actor.start()
                    proc['actor'] = actor
            for proc in infer_processes:
                if not proc['infer'].is_alive():
                    infer = mp.Process(
                        target=infer_logic,
                        args=(proc['i'], proc['device'], infer_queues, actor_model, flags, onnx_frame))
                    infer.daemon = True
                    infer.start()
                    proc['infer'] = infer

    except KeyboardInterrupt:
        # Skip uploading on manual interrupt; just save a final checkpoint and exit
        flags.enable_upload = False
        checkpoint(frames)
        return
    else:
        for thread in threads:
            thread.join()
        log.info('Learning finished after %d frames.', frames)

    checkpoint(frames)
    plogger.close()
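

# Usage sketch (illustrative; the real entry point is the repository's train.py, and the
# import path below is assumed rather than taken from this file):
#
#     from douzero.dmc.arguments import parser  # hypothetical location of the flag parser
#
#     flags = parser.parse_args()
#     train(flags)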