修复eval逻辑BUG
This commit is contained in:
parent
6dcfe074de
commit
82e941a5eb
|
@ -0,0 +1 @@
|
||||||
|
python train.py --load_model --batch_size 8 --learning_rate 0.0003
|
|
@ -19,6 +19,8 @@ parser.add_argument('--num_actor_devices', default=1, type=int,
|
||||||
help='The number of devices used for simulation')
|
help='The number of devices used for simulation')
|
||||||
parser.add_argument('--num_actors', default=2, type=int,
|
parser.add_argument('--num_actors', default=2, type=int,
|
||||||
help='The number of actors for each simulation device')
|
help='The number of actors for each simulation device')
|
||||||
|
parser.add_argument('--num_actors_cpu', default=1, type=int,
|
||||||
|
help='The number of actors for each simulation device')
|
||||||
parser.add_argument('--training_device', default='0', type=str,
|
parser.add_argument('--training_device', default='0', type=str,
|
||||||
help='The index of the GPU used for training models. `cpu` means using cpu')
|
help='The index of the GPU used for training models. `cpu` means using cpu')
|
||||||
parser.add_argument('--load_model', action='store_true',
|
parser.add_argument('--load_model', action='store_true',
|
||||||
|
|
|
@ -16,6 +16,7 @@ from .file_writer import FileWriter
|
||||||
from .models import Model, OldModel
|
from .models import Model, OldModel
|
||||||
from .utils import get_batch, log, create_env, create_optimizers, act
|
from .utils import get_batch, log, create_env, create_optimizers, act
|
||||||
import psutil
|
import psutil
|
||||||
|
import shutil
|
||||||
|
|
||||||
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']}
|
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']}
|
||||||
|
|
||||||
|
@ -95,7 +96,7 @@ def train(flags):
|
||||||
if flags.actor_device_cpu:
|
if flags.actor_device_cpu:
|
||||||
device_iterator = ['cpu']
|
device_iterator = ['cpu']
|
||||||
else:
|
else:
|
||||||
device_iterator = range(flags.num_actor_devices)
|
device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
|
||||||
assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
|
assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
|
||||||
|
|
||||||
# Initialize actor models
|
# Initialize actor models
|
||||||
|
@ -163,8 +164,11 @@ def train(flags):
|
||||||
|
|
||||||
# Starting actor processes
|
# Starting actor processes
|
||||||
for device in device_iterator:
|
for device in device_iterator:
|
||||||
|
if device == 'cpu':
|
||||||
|
num_actors = flags.num_actors_cpu
|
||||||
|
else:
|
||||||
num_actors = flags.num_actors
|
num_actors = flags.num_actors
|
||||||
for i in range(flags.num_actors):
|
for i in range(num_actors):
|
||||||
actor = mp.Process(
|
actor = mp.Process(
|
||||||
target=act,
|
target=act,
|
||||||
args=(i, device, batch_queues, models[device], flags))
|
args=(i, device, batch_queues, models[device], flags))
|
||||||
|
@ -220,13 +224,15 @@ def train(flags):
|
||||||
'flags': vars(flags),
|
'flags': vars(flags),
|
||||||
'frames': frames,
|
'frames': frames,
|
||||||
'position_frames': position_frames
|
'position_frames': position_frames
|
||||||
}, checkpointpath)
|
}, checkpointpath + '.new')
|
||||||
|
|
||||||
# Save the weights for evaluation purpose
|
# Save the weights for evaluation purpose
|
||||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
||||||
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt')))
|
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt')))
|
||||||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||||
|
shutil.move(checkpointpath + '.new', checkpointpath)
|
||||||
|
|
||||||
|
|
||||||
fps_log = []
|
fps_log = []
|
||||||
timer = timeit.default_timer
|
timer = timeit.default_timer
|
||||||
|
|
|
@ -127,7 +127,7 @@ def data_allocation_per_worker(card_play_data_list, num_workers):
|
||||||
|
|
||||||
return card_play_data_list_each_worker
|
return card_play_data_list_each_worker
|
||||||
|
|
||||||
def evaluate(landlord, landlord_up, landlord_down, eval_data, num_workers, output, output_bid, title):
|
def evaluate(landlord, landlord_up, landlord_front, landlord_down, eval_data, num_workers, output, output_bid, title):
|
||||||
|
|
||||||
with open(eval_data, 'rb') as f:
|
with open(eval_data, 'rb') as f:
|
||||||
card_play_data_list = pickle.load(f)
|
card_play_data_list = pickle.load(f)
|
||||||
|
@ -139,7 +139,7 @@ def evaluate(landlord, landlord_up, landlord_down, eval_data, num_workers, outpu
|
||||||
card_play_model_path_dict = {
|
card_play_model_path_dict = {
|
||||||
'landlord': landlord,
|
'landlord': landlord,
|
||||||
'landlord_up': landlord_up,
|
'landlord_up': landlord_up,
|
||||||
'landlord_front': landlord_up,
|
'landlord_front': landlord_front,
|
||||||
'landlord_down': landlord_down}
|
'landlord_down': landlord_down}
|
||||||
|
|
||||||
num_landlord_wins = 0
|
num_landlord_wins = 0
|
||||||
|
|
|
@ -38,6 +38,7 @@ def make_evaluate(args, t, frame, adp_frame, folder_a = 'baselines', folder_b =
|
||||||
|
|
||||||
evaluate(args.landlord,
|
evaluate(args.landlord,
|
||||||
args.landlord_up,
|
args.landlord_up,
|
||||||
|
args.landlord_front,
|
||||||
args.landlord_down,
|
args.landlord_down,
|
||||||
args.eval_data,
|
args.eval_data,
|
||||||
args.num_workers,
|
args.num_workers,
|
||||||
|
@ -59,7 +60,7 @@ if __name__ == '__main__':
|
||||||
default='baselines/douzero_12/landlord_down_weights_39762328900.ckpt')
|
default='baselines/douzero_12/landlord_down_weights_39762328900.ckpt')
|
||||||
parser.add_argument('--eval_data', type=str,
|
parser.add_argument('--eval_data', type=str,
|
||||||
default='eval_data_200.pkl')
|
default='eval_data_200.pkl')
|
||||||
parser.add_argument('--num_workers', type=int, default=5)
|
parser.add_argument('--num_workers', type=int, default=2)
|
||||||
parser.add_argument('--gpu_device', type=str, default='0')
|
parser.add_argument('--gpu_device', type=str, default='0')
|
||||||
parser.add_argument('--output', type=bool, default=True)
|
parser.add_argument('--output', type=bool, default=True)
|
||||||
parser.add_argument('--bid', type=bool, default=True)
|
parser.add_argument('--bid', type=bool, default=True)
|
||||||
|
@ -107,7 +108,7 @@ if __name__ == '__main__':
|
||||||
# [14102400, 4968800, 'baselines', 'baselines'],
|
# [14102400, 4968800, 'baselines', 'baselines'],
|
||||||
# [14102400, 13252000, 'baselines', 'baselines2'],
|
# [14102400, 13252000, 'baselines', 'baselines2'],
|
||||||
# [14102400, 15096800, 'baselines', 'baselines2'],
|
# [14102400, 15096800, 'baselines', 'baselines2'],
|
||||||
[14102400, 14102400, 'baselines', 'baselines'],
|
[34828000, 40132800, 'baselines2', 'baselines2'],
|
||||||
# [14102400, None, 'baselines', 'baselines'],
|
# [14102400, None, 'baselines', 'baselines'],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
import os.path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from douzero.dmc.file_writer import FileWriter
|
||||||
|
from douzero.dmc.models import Model, OldModel
|
||||||
|
from douzero.dmc.utils import get_batch, log, create_env, create_optimizers, act
|
||||||
|
|
||||||
|
# Merge script: load a full training checkpoint, swap selected per-position
# weights for standalone ``.ckpt`` files, and write the result to a new
# checkpoint file with the same top-level layout.

# Learner model whose per-position networks are patched and re-saved below.
learner_model = Model(device="0")


# Minimal stand-in for the training flags object passed to create_optimizers.
# NOTE(review): presumably these four attributes (used as lr / momentum /
# eps / alpha) are all that create_optimizers reads -- confirm against
# douzero.dmc.utils before adding or removing fields.
class myflags:
    learning_rate = 0.0003
    momentum = 0
    alpha = 0.99
    epsilon = 1e-5


flags = myflags
checkpointpath = "merger/model.tar"        # input: checkpoint to patch
merged_path = "merger/model_merged.tar"    # output: patched checkpoint
optimizers = create_optimizers(flags, learner_model)

# CPU-side model instances keyed by device. They are built here but not
# read again anywhere in this script.
models = {}
device_iterator = ["cpu"]
for device in device_iterator:
    model = Model(device="cpu")
    model.share_memory()
    model.eval()
    models[device] = model

checkpoint_states = torch.load(checkpointpath)

print("Load original weights")
for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']:
    learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
    optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
stats = checkpoint_states["stats"]

print("Load replace weights")
for k in ['landlord']:
    replacement_path = "merger/resnet_" + k + ".ckpt"
    # Positions without a replacement file keep the checkpoint weights.
    if not os.path.exists(replacement_path):
        continue
    weights = torch.load(replacement_path, map_location="cuda:0")
    # The replacement must be the LAST state loaded for this position.
    # Reloading checkpoint_states["model_state_dict"][k] after this call
    # (as the previous version did) overwrites the replacement and makes
    # the merged file identical to the input checkpoint.
    learner_model.get_model(k).load_state_dict(weights)

frames = checkpoint_states["frames"]
# frames = 3085177900
position_frames = checkpoint_states["position_frames"]
log.info(f"Resuming preempted job, current stats:\n{stats}")

# Report the path that is actually written (merged_path, not the input path).
log.info('Saving checkpoint to %s', merged_path)
_models = learner_model.get_models()

torch.save({
    'model_state_dict': {k: _models[k].state_dict() for k in _models},
    'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
    "stats": stats,
    'flags': checkpoint_states["flags"],
    'frames': frames,
    'position_frames': position_frames
}, merged_path)
|
Loading…
Reference in New Issue