Add vanilla-model training logic and a toggle flag
parent 3300fd9658 · commit 05aa179ba6
@@ -23,6 +23,8 @@ parser.add_argument('--training_device', default='0', type=str,
                     help='The index of the GPU used for training models. `cpu` means using cpu')
 parser.add_argument('--load_model', action='store_true',
                     help='Load an existing model')
+parser.add_argument('--old_model', action='store_true',
+                    help='Use vanilla model')
 parser.add_argument('--disable_checkpoint', action='store_true',
                     help='Disable saving checkpoint')
 parser.add_argument('--savedir', default='douzero_checkpoints',
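With the switch in place, vanilla (LSTM) training can presumably be enabled from the command line. A minimal sketch of how the flag surfaces in `flags` (the entry-point name and the second argument are illustrative, not taken from this commit):

    # Illustrative only: how --old_model is expected to appear on the flags object.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--old_model', action='store_true',
                        help='Use vanilla model')
    parser.add_argument('--training_device', default='0', type=str)

    flags = parser.parse_args(['--old_model'])   # e.g. python train.py --old_model
    print(flags.old_model)                       # True -> the OldModel code paths are taken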
@@ -15,6 +15,7 @@ import douzero.env.env
 from .file_writer import FileWriter
 from .models import Model, OldModel
 from .utils import get_batch, log, create_env, create_optimizers, act
+import psutil

 mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']}

@@ -33,8 +34,14 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock):
         device = torch.device('cuda:'+str(flags.training_device))
     else:
         device = torch.device('cpu')
-    obs_x = batch["obs_x_batch"]
-    obs_x = torch.flatten(obs_x, 0, 1).to(device)
+    if flags.old_model and position != 'bidding':
+        obs_x_no_action = batch['obs_x_no_action'].to(device)
+        obs_action = batch['obs_x_batch'].to(device)
+        obs_x = torch.cat((obs_x_no_action, obs_action), dim=2).float()
+        obs_x = torch.flatten(obs_x, 0, 1)
+    else:
+        obs_x = batch["obs_x_batch"]
+        obs_x = torch.flatten(obs_x, 0, 1).to(device)
     obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
     target = torch.flatten(batch['target'].to(device), 0, 1)
     if position != "bidding":
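A quick shape sketch of what the old-model branch in learn does: the per-step action encoding shipped under the obs_x_batch key is concatenated back onto obs_x_no_action along the feature axis, then the unroll and batch dimensions are flattened together. Sizes below are placeholders, not the repo's real feature widths:

    import torch

    T, B = 100, 32                               # unroll length and batch size (assumed)
    obs_x_no_action = torch.zeros(T, B, 512)     # flat features without the action (size illustrative)
    obs_action = torch.zeros(T, B, 64)           # per-step action encoding (size illustrative)

    obs_x = torch.cat((obs_x_no_action, obs_action), dim=2).float()  # (T, B, 512 + 64)
    obs_x = torch.flatten(obs_x, 0, 1)                               # (T*B, 512 + 64)
    assert obs_x.shape == (T * B, 512 + 64)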
@@ -94,7 +101,10 @@ def train(flags):
     # Initialize actor models
     models = {}
     for device in device_iterator:
-        model = Model(device="cpu")
+        if flags.old_model:
+            model = OldModel(device="cpu")
+        else:
+            model = Model(device="cpu")
         model.share_memory()
         model.eval()
         models[device] = model
@@ -105,7 +115,10 @@ def train(flags):
     batch_queues = {"landlord": ctx.SimpleQueue(), "landlord_up": ctx.SimpleQueue(), 'landlord_front': ctx.SimpleQueue(), "landlord_down": ctx.SimpleQueue(), "bidding": ctx.SimpleQueue()}

     # Learner model for training
-    learner_model = Model(device=flags.training_device)
+    if flags.old_model:
+        learner_model = OldModel(device=flags.training_device)
+    else:
+        learner_model = Model(device=flags.training_device)

     # Create optimizers
     optimizers = create_optimizers(flags, learner_model)
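The same old_model switch is repeated for the actor models above and for the learner model here. One way to centralize it (a hypothetical helper, not part of the commit) would be:

    # Hypothetical refactoring sketch; Model and OldModel are the classes
    # imported from .models in this diff.
    def make_model(flags, device):
        model_cls = OldModel if flags.old_model else Model
        return model_cls(device=device)

    # usage (illustrative):
    #   models[device] = make_model(flags, "cpu")
    #   learner_model = make_model(flags, flags.training_device)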
@@ -159,6 +172,11 @@ def train(flags):
             actor.start()
             actor_processes.append(actor)

+    parent = psutil.Process()
+    parent.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)
+    for child in parent.children():
+        child.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)
+
     def batch_and_learn(i, device, position, local_lock, position_lock, lock=threading.Lock()):
         """Thread target for the learning process."""
         nonlocal frames, position_frames, stats
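One caveat worth noting: psutil.BELOW_NORMAL_PRIORITY_CLASS is only defined on Windows builds of psutil, so the block above is Windows-specific; on POSIX systems process priority is expressed as a nice value. A cross-platform variant (a sketch, not what the commit does) could look like:

    import sys
    import psutil

    parent = psutil.Process()
    # Windows exposes priority classes; POSIX uses nice values (higher = lower priority).
    if sys.platform == 'win32':
        low_priority = psutil.BELOW_NORMAL_PRIORITY_CLASS
    else:
        low_priority = 10
    parent.nice(low_priority)
    for child in parent.children():
        child.nice(low_priority)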
@@ -13,7 +13,7 @@ class LandlordLstmModel(nn.Module):
     def __init__(self):
         super().__init__()
         self.lstm = nn.LSTM(432, 128, batch_first=True)
-        self.dense1 = nn.Linear(846 + 128, 1024)
+        self.dense1 = nn.Linear(860 + 128, 1024)
         self.dense2 = nn.Linear(1024, 1024)
         self.dense3 = nn.Linear(1024, 768)
         self.dense4 = nn.Linear(768, 512)
@@ -48,7 +48,7 @@ class FarmerLstmModel(nn.Module):
     def __init__(self):
         super().__init__()
         self.lstm = nn.LSTM(432, 128, batch_first=True)
-        self.dense1 = nn.Linear(1178 + 128, 1024)
+        self.dense1 = nn.Linear(1192 + 128, 1024)
         self.dense2 = nn.Linear(1024, 1024)
         self.dense3 = nn.Linear(1024, 768)
         self.dense4 = nn.Linear(768, 512)
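Both vanilla heads grow their flat-feature input by 14 (846 -> 860 for the landlord, 1178 -> 1192 for the farmers) while the LSTM over z stays at 432 -> 128. A quick shape check of the landlord head's first layer, assuming the model fuses the last LSTM output with the flat features as in the original DouZero LSTM models:

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(432, 128, batch_first=True)
    dense1 = nn.Linear(860 + 128, 1024)

    z = torch.zeros(4, 5, 432)          # (batch, time, 432) history encoding; sizes illustrative
    x = torch.zeros(4, 860)             # flat features, now 860-dim for the landlord

    lstm_out, _ = lstm(z)
    fused = torch.cat([lstm_out[:, -1, :], x], dim=-1)    # (4, 988)
    print(dense1(fused).shape)                            # torch.Size([4, 1024])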
@@ -255,14 +255,14 @@ class GeneralModel(nn.Module):
         self.in_planes = 80
         #input 1*108*41
         self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
-                               stride=(2,), padding=1, bias=False) #1*54*80
+                               stride=(2,), padding=1, bias=False) #1*108*80

         self.bn1 = nn.BatchNorm1d(80)

         self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)#1*27*80
         self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)#1*14*160
         self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*7*320
-        self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*320
+        self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*640
         # self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
         self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048)
         self.linear2 = nn.Linear(2048, 1024)
@@ -416,6 +416,7 @@ class OldModel:
         self.models['landlord_up'] = FarmerLstmModel().to(torch.device(device))
+        self.models['landlord_front'] = FarmerLstmModel().to(torch.device(device))
         self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
         self.models['bidding'] = BidModel().to(torch.device(device))

     def forward(self, position, z, x, training=False, flags=None):
         model = self.models[position]
@@ -426,12 +427,14 @@ class OldModel:
         self.models['landlord_up'].share_memory()
+        self.models['landlord_front'].share_memory()
         self.models['landlord_down'].share_memory()
         self.models['bidding'].share_memory()

     def eval(self):
         self.models['landlord'].eval()
         self.models['landlord_up'].eval()
+        self.models['landlord_front'].eval()
         self.models['landlord_down'].eval()
         self.models['bidding'].eval()

     def parameters(self, position):
         return self.models[position].parameters()
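Since share_memory() and eval() are applied uniformly to every per-position model, the growing list of positions could also be handled with a loop; a sketch of that alternative (not part of the commit):

    # Hypothetical alternative bodies for OldModel.share_memory / OldModel.eval.
    def share_memory(self):
        for model in self.models.values():
            model.share_memory()

    def eval(self):
        for model in self.models.values():
            model.eval()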
@@ -41,7 +41,7 @@ log.setLevel(logging.INFO)
 Buffers = typing.Dict[str, typing.List[torch.Tensor]]

 def create_env(flags):
-    return Env(flags.objective)
+    return Env(flags.objective, flags.old_model)

 def get_batch(b_queues, position, flags, lock):
     """
@@ -55,7 +55,7 @@ def get_batch(b_queues, position, flags, lock):
         buffer.append(b_queue.get())
     batch = {
         key: torch.stack([m[key] for m in buffer], dim=1)
-        for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_type"]
+        for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"]
     }
     del buffer
     return batch
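For context, get_batch collects several rollout chunks from the queue and stacks each key along a new batch dimension, so every tensor the learner sees is (T, B, ...). A minimal reproduction of that stacking, with illustrative buffer contents and sizes:

    import torch

    T, B = 100, 3                                  # unroll length and chunks per batch (assumed)
    buffer = [{"target": torch.zeros(T)} for _ in range(B)]

    batch = {
        key: torch.stack([m[key] for m in buffer], dim=1)
        for key in ["target"]
    }
    print(batch["target"].shape)                   # torch.Size([100, 3]) -> (T, B)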
@@ -93,6 +93,9 @@ def act(i, device, batch_queues, model, flags):
         size = {p: 0 for p in positions}
         type_buf = {p: [] for p in positions}
         obs_x_batch_buf = {p: [] for p in positions}
+        if flags.old_model:
+            obs_action_buf = {p: [] for p in positions}
+            obs_x_no_action = {p: [] for p in positions}

         position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}
         bid_type_index = {"landlord": 41, "landlord_up": 42, "landlord_front": 43, "landlord_down": 44}
@@ -119,10 +122,15 @@ def act(i, device, batch_queues, model, flags):
                 agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
                 _action_idx = int(agent_output['action'].cpu().detach().numpy())
                 action = obs['legal_actions'][_action_idx]
-                obs_z_buf[position].append(torch.vstack((_cards2tensor(action).unsqueeze(0), env_output['obs_z'])).float())
-                # x_batch = torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), dim=0).float()
-                x_batch = env_output['obs_x_no_action'].float()
-                obs_x_batch_buf[position].append(x_batch)
+                if flags.old_model and position != 'bidding':
+                    obs_action_buf[position].append(_cards2tensor(action))
+                    obs_x_no_action[position].append(env_output['obs_x_no_action'])
+                    obs_z_buf[position].append(env_output['obs_z'])
+                else:
+                    obs_z_buf[position].append(torch.vstack((_cards2tensor(action).unsqueeze(0), env_output['obs_z'])).float())
+                    # x_batch = torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), dim=0).float()
+                    x_batch = env_output['obs_x_no_action'].float()
+                    obs_x_batch_buf[position].append(x_batch)
                 type_buf[position].append(position_index[position])
                 position, obs, env_output = env.step(action, model, device, flags=flags)
                 size[position] += 1
@@ -155,18 +163,31 @@ def act(i, device, batch_queues, model, flags):
             for p in positions:
                 while size[p] > T:
                     # print(p, "epr", torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),)
-                    batch_queues[p].put({
-                        "done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
-                        "episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
-                        "target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
-                        "obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
-                        "obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_batch_buf[p][:T]]),
-                        "obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
-                    })
+                    if flags.old_model and p != 'bidding':
+                        batch_queues[p].put({
+                            "done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
+                            "episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
+                            "target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
+                            "obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
+                            "obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_action_buf[p][:T]]),
+                            "obs_x_no_action": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_no_action[p][:T]]),
+                            "obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
+                        })
+                        obs_action_buf[p] = obs_action_buf[p][T:]
+                        obs_x_no_action[p] = obs_x_no_action[p][T:]
+                    else:
+                        batch_queues[p].put({
+                            "done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
+                            "episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
+                            "target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
+                            "obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
+                            "obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_batch_buf[p][:T]]),
+                            "obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
+                        })
+                        obs_x_batch_buf[p] = obs_x_batch_buf[p][T:]
                     done_buf[p] = done_buf[p][T:]
                     episode_return_buf[p] = episode_return_buf[p][T:]
                     target_buf[p] = target_buf[p][T:]
-                    obs_x_batch_buf[p] = obs_x_batch_buf[p][T:]
                     obs_z_buf[p] = obs_z_buf[p][T:]
                     type_buf[p] = type_buf[p][T:]
                     size[p] -= T
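The queueing above relies on a simple rolling-buffer pattern: the first T entries of each per-position list are shipped, and the remainder is carried over to the next batch. In isolation, with illustrative values:

    T = 100
    done_buf = list(range(250))        # pretend 250 transitions have accumulated (illustrative)

    chunk, done_buf = done_buf[:T], done_buf[T:]
    print(len(chunk), len(done_buf))   # 100 150 -> one batch sent, the rest kept for next time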
@@ -34,7 +34,7 @@ class Env:
    Doudizhu multi-agent wrapper
    """

-    def __init__(self, objective):
+    def __init__(self, objective, old_model):
        """
        Objective is wp/adp/logadp. It indicates whether considers
        bomb in reward calculation. Here, we use dummy agents.
@@ -47,6 +47,7 @@ class Env:
        will perform the actual action in the game engine.
        """
        self.objective = objective
+        self.use_general = not old_model

        # Initialize players
        # We use three dummy player for the target position
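Downstream, get_obs now receives this use_general flag, so the environment presumably switches between the general (ResNet-style) and vanilla observation encodings. A hypothetical dispatcher illustrating the shape of that call (the helper names are made up; the repo's real get_obs in douzero.env.env may be structured differently):

    # Hypothetical names used purely for illustration.
    def _get_obs_general(infoset):
        return {}   # placeholder

    def _get_obs_vanilla(infoset):
        return {}   # placeholder

    def get_obs(infoset, use_general=True):
        return _get_obs_general(infoset) if use_general else _get_obs_vanilla(infoset)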
@@ -82,7 +83,7 @@ class Env:
                 card_play_data[key].sort()
             self._env.card_play_init(card_play_data)
             self.infoset = self._game_infoset
-            return get_obs(self.infoset)
+            return get_obs(self.infoset, self.use_general)
         else:
             self.total_round += 1
             bid_done = False
@@ -232,7 +233,7 @@ class Env:
                 print("发牌情况: %i/%i %.1f%%" % (self.force_bid, self.total_round, self.force_bid / self.total_round * 100))
                 self.force_bid = 0
                 self.total_round = 0
-            return get_obs(self.infoset), {
+            return get_obs(self.infoset, self.use_general), {
                 "bid_obs_buffer": bid_obs_buffer,
                 "multiply_obs_buffer": multiply_obs_buffer
             }
@@ -269,7 +270,7 @@ class Env:
             }
             obs = None
         else:
-            obs = get_obs(self.infoset)
+            obs = get_obs(self.infoset, self.use_general)
         return obs, reward, done, {}

     def _get_reward(self, pos):