diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index bef2003..1365313 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -19,6 +19,7 @@ import psutil import shutil mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']} +onnx_frame = mp.Value('d', -1) def compute_loss(logits, targets): loss = ((logits.squeeze(-1) - targets)**2).mean() @@ -72,13 +73,14 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock): actor_model.get_model(position).load_state_dict(model.state_dict()) return stats -def train(flags): +def train(flags): """ This is the main funtion for training. It will first initilize everything, such as buffers, optimizers, etc. Then it will start subprocesses as actors. Then, it will call learning function with multiple threads. """ + global onnx_frame if not flags.actor_device_cpu or flags.training_device != 'cpu': if not torch.cuda.is_available(): raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`") @@ -171,7 +173,7 @@ def train(flags): for i in range(num_actors): actor = mp.Process( target=act, - args=(i, device, batch_queues, models[device], flags)) + args=(i, device, batch_queues, models[device], flags, onnx_frame)) actor.daemon = True actor.start() actor_processes.append(actor) @@ -186,7 +188,7 @@ def train(flags): nonlocal frames, position_frames, stats while frames < flags.total_frames: batch = get_batch(batch_queues, position, flags, local_lock) - _stats = learn(position, models, learner_model.get_model(position), batch, + _stats = learn(position, models, learner_model.get_model(position), batch, optimizers[position], flags, position_lock) with lock: for k in _stats: @@ -212,8 +214,9 @@ def train(flags): thread.setDaemon(True) thread.start() threads.append(thread) - + def checkpoint(frames): + global onnx_frame if flags.disable_checkpoint: return log.info('Saving checkpoint to %s', checkpointpath) @@ -228,25 +231,32 @@ def train(flags): }, checkpointpath + '.new') # Save the weights for evaluation purpose - dummy_input = (torch.tensor(np.zeros([80, 40, 108]), dtype=torch.float32), - torch.tensor(np.zeros((80, 80)), dtype=torch.int8), - { - 'return_value': False, - 'flags': {'exp_epsilon':0.001} - }, - ) - for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] + dummy_input = ( + torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32), + torch.tensor(np.zeros((1, 80)), dtype=torch.float32) + ) + for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']: model_weights_dir = os.path.expandvars(os.path.expanduser( - '%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt'))) + '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt'))) torch.save(learner_model.get_model(position).state_dict(), model_weights_dir) if position != 'bidding': + model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position) torch.onnx.export( learner_model.get_model(position), dummy_input, - '%s/model_%s.onnx' % (flags.savedir, position), - input_names=['z_batch','x_batch','flags'], - output_names=['action', 'max_value'] + model_path, + input_names=['z_batch','x_batch'], + output_names=['values'], + dynamic_axes={ + 'z_batch': { + 0: "legal_actions" + }, + 'x_batch': { + 0: "legal_actions" + } + } ) + onnx_frame = frames shutil.move(checkpointpath + '.new', checkpointpath) @@ -260,7 +270,7 @@ def train(flags): start_time = timer() time.sleep(5) - if timer() - last_checkpoint_time > flags.save_interval * 60: + if timer() - last_checkpoint_time > flags.save_interval * 60: checkpoint(frames) last_checkpoint_time = timer() end_time = timer() @@ -287,7 +297,7 @@ def train(flags): pprint.pformat(stats)) except KeyboardInterrupt: - return + return else: for thread in threads: thread.join() diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py index e2be986..42aa19d 100644 --- a/douzero/dmc/models.py +++ b/douzero/dmc/models.py @@ -6,9 +6,14 @@ models into one class for convenience. import numpy as np import torch +import onnxruntime +from onnxruntime.datasets import get_example from torch import nn import torch.nn.functional as F +def to_numpy(tensor): + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + class LandlordLstmModel(nn.Module): def __init__(self): super().__init__() @@ -20,7 +25,7 @@ class LandlordLstmModel(nn.Module): self.dense5 = nn.Linear(512, 256) self.dense6 = nn.Linear(256, 1) - def forward(self, z, x, return_value=False, flags=None): + def forward(self, z, x): lstm_out, (h_n, _) = self.lstm(z) lstm_out = lstm_out[:,-1,:] x = torch.cat([lstm_out,x], dim=-1) @@ -35,14 +40,7 @@ class LandlordLstmModel(nn.Module): x = self.dense5(x) x = torch.relu(x) x = self.dense6(x) - if return_value: - return dict(values=x) - else: - if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: - action = torch.randint(x.shape[0], (1,))[0] - else: - action = torch.argmax(x,dim=0)[0] - return dict(action=action) + return dict(values=x) class FarmerLstmModel(nn.Module): def __init__(self): @@ -55,7 +53,7 @@ class FarmerLstmModel(nn.Module): self.dense5 = nn.Linear(512, 256) self.dense6 = nn.Linear(256, 1) - def forward(self, z, x, return_value=False, flags=None): + def forward(self, z, x): lstm_out, (h_n, _) = self.lstm(z) lstm_out = lstm_out[:,-1,:] x = torch.cat([lstm_out,x], dim=-1) @@ -70,14 +68,7 @@ class FarmerLstmModel(nn.Module): x = self.dense5(x) x = torch.relu(x) x = self.dense6(x) - if return_value: - return dict(values=x) - else: - if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: - action = torch.randint(x.shape[0], (1,))[0] - else: - action = torch.argmax(x,dim=0)[0] - return dict(action=action) + return dict(values=x) class LandlordLstmModelLegacy(nn.Module): def __init__(self): @@ -90,7 +81,7 @@ class LandlordLstmModelLegacy(nn.Module): self.dense5 = nn.Linear(512, 256) self.dense6 = nn.Linear(256, 1) - def forward(self, z, x, return_value=False, flags=None): + def forward(self, z, x): lstm_out, (h_n, _) = self.lstm(z) lstm_out = lstm_out[:,-1,:] x = torch.cat([lstm_out,x], dim=-1) @@ -105,14 +96,7 @@ class LandlordLstmModelLegacy(nn.Module): x = self.dense5(x) x = torch.relu(x) x = self.dense6(x) - if return_value: - return dict(values=x) - else: - if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: - action = torch.randint(x.shape[0], (1,))[0] - else: - action = torch.argmax(x,dim=0)[0] - return dict(action=action) + return dict(values=x) class FarmerLstmModelLegacy(nn.Module): def __init__(self): @@ -125,7 +109,7 @@ class FarmerLstmModelLegacy(nn.Module): self.dense5 = nn.Linear(512, 256) self.dense6 = nn.Linear(256, 1) - def forward(self, z, x, return_value=False, flags=None): + def forward(self, z, x): lstm_out, (h_n, _) = self.lstm(z) lstm_out = lstm_out[:,-1,:] x = torch.cat([lstm_out,x], dim=-1) @@ -140,14 +124,7 @@ class FarmerLstmModelLegacy(nn.Module): x = self.dense5(x) x = torch.relu(x) x = self.dense6(x) - if return_value: - return dict(values=x) - else: - if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: - action = torch.randint(x.shape[0], (1,))[0] - else: - action = torch.argmax(x,dim=0)[0] - return dict(action=action) + return dict(values=x) class GeneralModel1(nn.Module): def __init__(self): @@ -185,7 +162,7 @@ class GeneralModel1(nn.Module): self.dense5 = nn.Linear(512, 512) self.dense6 = nn.Linear(512, 1) - def forward(self, z, x, return_value=False, flags=None, debug=False): + def forward(self, z, x): z = z.unsqueeze(1) z = self.conv_z_1(z) z = z.squeeze(-1) @@ -209,14 +186,7 @@ class GeneralModel1(nn.Module): x = self.dense5(x) x = torch.relu(x) x = self.dense6(x) - if return_value: - return dict(values=x) - else: - if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: - action = torch.randint(x.shape[0], (1,))[0] - else: - action = torch.argmax(x,dim=0)[0] - return dict(action=action, max_value=torch.max(x)) + return dict(values=x) # 用于ResNet18和34的残差块,用的是2个3x3的卷积 @@ -278,7 +248,7 @@ class GeneralModelLegacy(nn.Module): self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def forward(self, z, x, return_value=False, flags=None, debug=False): + def forward(self, z, x): out = F.relu(self.bn1(self.conv1(z))) out = self.layer1(out) out = self.layer2(out) @@ -291,14 +261,7 @@ class GeneralModelLegacy(nn.Module): out = F.leaky_relu_(self.linear3(out)) out = F.leaky_relu_(self.linear4(out)) out = F.leaky_relu_(self.linear5(out)) - if return_value: - return dict(values=out) - else: - if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: - action = torch.randint(out.shape[0], (1,))[0] - else: - action = torch.argmax(out,dim=0)[0] - return dict(action=action, max_value=torch.max(out)) + return dict(values=out) class GeneralModel(nn.Module): def __init__(self): @@ -329,7 +292,7 @@ class GeneralModel(nn.Module): self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def forward(self, z, x, return_value=False, flags=None, debug=False): + def forward(self, z, x): out = F.relu(self.bn1(self.conv1(z))) out = self.layer1(out) out = self.layer2(out) @@ -342,17 +305,7 @@ class GeneralModel(nn.Module): out = F.leaky_relu_(self.linear3(out)) out = F.leaky_relu_(self.linear4(out)) out = F.leaky_relu_(self.linear5(out)) - if return_value: - return dict(values=out) - else: - if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: - action = torch.randint(out.shape[0], (1,))[0] - else: - action = torch.argmax(out,dim=0)[0] - return dict(action=action, max_value=torch.max(out)) - - - + return dict(values=out) class BidModel(nn.Module): @@ -366,7 +319,7 @@ class BidModel(nn.Module): self.dense5 = nn.Linear(512, 512) self.dense6 = nn.Linear(512, 1) - def forward(self, z, x, return_value=False, flags=None, debug=False): + def forward(self, z, x): x = self.dense1(x) x = F.leaky_relu(x) # x = F.relu(x) @@ -383,14 +336,7 @@ class BidModel(nn.Module): # x = F.relu(x) x = F.leaky_relu(x) x = self.dense6(x) - if return_value: - return dict(values=x) - else: - if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: - action = torch.randint(x.shape[0], (1,))[0] - else: - action = torch.argmax(x,dim=0)[0] - return dict(action=action, max_value=torch.max(x)) + return dict(values=x) # Model dict is only used in evaluation but not training @@ -438,9 +384,17 @@ class General_Model: self.models['landlord_down'] = GeneralModel1().to(torch.device(device)) self.models['bidding'] = BidModel().to(torch.device(device)) - def forward(self, position, z, x, training=False, flags=None, debug=False): + def forward(self, position, z, x, return_value=False, flags=None, debug=False): model = self.models[position] - return model.forward(z, x, training, flags, debug) + values = model.forward(z, x)['values'] + if return_value: + return dict(values=values) + else: + if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: + action = torch.randint(values.shape[0], (1,))[0] + else: + action = torch.argmax(values,dim=0)[0] + return dict(action=action, max_value=torch.max(values)) def share_memory(self): self.models['landlord'].share_memory() @@ -480,9 +434,17 @@ class OldModel: self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device)) self.models['bidding'] = BidModel().to(torch.device(device)) - def forward(self, position, z, x, training=False, flags=None): + def forward(self, position, z, x, return_value=False, flags=None): model = self.models[position] - return model.forward(z, x, training, flags) + values = model.forward(z, x)['values'] + if return_value: + return dict(values=values) + else: + if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: + action = torch.randint(values.shape[0], (1,))[0] + else: + action = torch.argmax(values,dim=0)[0] + return dict(action=action, max_value=torch.max(values)) def share_memory(self): self.models['landlord'].share_memory() @@ -523,10 +485,35 @@ class Model: self.models['landlord_front'] = GeneralModel().to(torch.device(device)) self.models['landlord_down'] = GeneralModel().to(torch.device(device)) self.models['bidding'] = BidModel().to(torch.device(device)) + self.onnx_models = { + 'landlord': None, + 'landlord_up': None, + 'landlord_front': None, + 'landlord_down': None, + 'bidding': None + } + self.models['bidding'] = BidModel().to(torch.device(device)) - def forward(self, position, z, x, training=False, flags=None, debug=False): - model = self.models[position] - return model.forward(z, x, training, flags, debug) + def set_onnx_model(self, position, model_path): + self.onnx_models[position] = get_example(model_path) + + def forward(self, position, z, x, return_value=False, flags=None, debug=False): + model = self.onnx_models[position] + if model is None: + model = self.models[position] + values = model.forward(z, x)['values'] + else: + sess = onnxruntime.InferenceSession(model) + onnx_out = sess.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)}) + values = torch.tensor(onnx_out[0]) + if return_value: + return dict(values=values) + else: + if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: + action = torch.randint(values.shape[0], (1,))[0] + else: + action = torch.argmax(values,dim=0)[0] + return dict(action=action, max_value=torch.max(values)) def share_memory(self): self.models['landlord'].share_memory() diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index 4002be0..dd7c0f2 100644 --- a/douzero/dmc/utils.py +++ b/douzero/dmc/utils.py @@ -81,7 +81,7 @@ def create_optimizers(flags, learner_model): return optimizers -def act(i, device, batch_queues, model, flags): +def act(i, device, batch_queues, model, flags, onnx_frame): positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding'] for pos in positions: model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device)))) @@ -110,8 +110,16 @@ def act(i, device, batch_queues, model, flags): position, obs, env_output = env.initial(model, device, flags=flags) bid_obs_buffer = env_output["begin_buf"]["bid_obs_buffer"] multiply_obs_buffer = env_output["begin_buf"]["multiply_obs_buffer"] + last_onnx_frame = -1 while True: # print("posi", position) + if onnx_frame != last_onnx_frame: + last_onnx_frame = onnx_frame + for p in positions: + if p != 'bidding': + model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, p) + model.set_onnx_model(p, os.path.abspath(model_path)) + for bid_obs in bid_obs_buffer: obs_z_buf["bidding"].append(bid_obs['z_batch']) obs_x_batch_buf["bidding"].append(bid_obs["x_batch"]) diff --git a/evaluate.py b/evaluate.py index 08d4f3c..ad276b7 100644 --- a/evaluate.py +++ b/evaluate.py @@ -61,8 +61,8 @@ if __name__ == '__main__': parser.add_argument('--bid', type=bool, default=True) parser.add_argument('--title', type=str, default='New') args = parser.parse_args() - # args.output = True - args.output = False + args.output = True + # args.output = False args.bid = False if args.output or args.bid: args.num_workers = 1 @@ -72,15 +72,30 @@ if __name__ == '__main__': eval_list = [ # { - # 'landlord': { 'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 48545600}, - # 'farmer': { 'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400}, - # 'two_way': True + # 'landlord': {'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 143539200}, + # 'farmer': {'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 143539200}, + # 'two_way': False # }, { - 'landlord': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400}, - 'farmer': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400}, + 'farmer': { 'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 143539200}, + 'landlord': { 'folder': 'baselines', 'prefix': 'resnet', 'frame': 23358800}, 'two_way': True - } + }, + # { + # 'landlord': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400}, + # 'farmer': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400}, + # 'two_way': False + # }, + # { + # 'landlord': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400}, + # 'farmer': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400}, + # 'two_way': False + # }, + # { + # 'landlord': {'folder': 'baselines', 'prefix': 'legacy_resnet', 'frame': 11754400}, + # 'farmer': {'folder': 'baselines', 'prefix': 'resnet', 'frame': 11534400}, + # 'two_way': True + # }, ] for vs in reversed(eval_list): diff --git a/generate_eval_data_with_bid.py b/generate_eval_data_with_bid.py index 1855dc8..2c4701b 100644 --- a/generate_eval_data_with_bid.py +++ b/generate_eval_data_with_bid.py @@ -15,8 +15,8 @@ deck.extend([20, 20, 30, 30]) def get_parser(): parser = argparse.ArgumentParser(description='DouZero: random data generator') - parser.add_argument('--output', default='eval_data_200', type=str) - parser.add_argument('--path', default='baselines/resnet_bidding_27853600.ckpt', type=str) + parser.add_argument('--output', default='eval_data_500', type=str) + parser.add_argument('--path', default='baselines/resnet_landlord_23358800.ckpt', type=str) parser.add_argument('--num_games', default=200, type=int) parser.add_argument('--exp_epsilon', default=0.01, type=float) return parser diff --git a/requirements.txt b/requirements.txt index 0e25259..41b5f02 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,5 @@ GitPython gitdb2 rlcard psutil +onnx +onnxruntime