移除无效参数

This commit is contained in:
ZaneYork 2022-01-02 22:33:48 +08:00
parent d76cac1335
commit 4571bb3dfc
4 changed files with 52 additions and 57 deletions

View File

@ -187,7 +187,7 @@ def train(flags):
for i in range(num_actors): for i in range(num_actors):
actor = mp.Process( actor = mp.Process(
target=act, target=act,
args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, actor_model, flags)) args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags))
actor.daemon = True actor.daemon = True
actor.start() actor.start()
actor_processes.append({ actor_processes.append({
@ -306,14 +306,24 @@ def train(flags):
position_fps['landlord_front'], position_fps['landlord_front'],
position_fps['landlord_down'], position_fps['landlord_down'],
pprint.pformat(stats)) pprint.pformat(stats))
# for proc in actor_processes: for proc in actor_processes:
# if not proc['actor'].is_alive(): if not proc['actor'].is_alive():
# actor = mp.Process( i = proc['i']
# target=act, actor = mp.Process(
# args=(proc['i'], proc['device'], batch_queues, models[device], flags, onnx_frame)) target=act,
# actor.daemon = True args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags))
# actor.start() actor.daemon = True
# proc['actor'] = actor actor.start()
proc['actor'] = actor
# Watchdog: restart any inference worker process that is no longer alive.
for proc in infer_processes:
    if not proc['infer'].is_alive():
        infer = mp.Process(
            target=infer_logic,
            args=(proc['i'], proc['device'], infer_queues, actor_model, flags, onnx_frame))
        infer.daemon = True
        infer.start()
        # Record the replacement inference process itself. The previous code
        # stored `actor` here, so the next liveness check would inspect an
        # unrelated actor process and lose track of the new worker.
        proc['infer'] = infer
except KeyboardInterrupt: except KeyboardInterrupt:
flags.enable_upload = False flags.enable_upload = False

View File

@ -30,8 +30,8 @@ class Environment:
self.device = device self.device = device
self.episode_return = None self.episode_return = None
def initial(self, model, device, flags=None): def initial(self, flags=None):
obs = self.env.reset(model, device, flags=flags) obs = self.env.reset(flags=flags)
initial_position, initial_obs, x_no_action, z = _format_observation(obs) initial_position, initial_obs, x_no_action, z = _format_observation(obs)
self.episode_return = torch.zeros(1, 1) self.episode_return = torch.zeros(1, 1)
initial_done = torch.ones(1, 1, dtype=torch.bool) initial_done = torch.ones(1, 1, dtype=torch.bool)
@ -42,13 +42,13 @@ class Environment:
obs_z=z, obs_z=z,
) )
def step(self, action, model, device, flags=None): def step(self, action, flags=None):
obs, reward, done, _ = self.env.step(action) obs, reward, done, _ = self.env.step(action)
self.episode_return = reward self.episode_return = reward
episode_return = self.episode_return episode_return = self.episode_return
if done: if done:
obs = self.env.reset(model, device, flags=flags) obs = self.env.reset(flags=flags)
self.episode_return = torch.zeros(1, 1) self.episode_return = torch.zeros(1, 1)
position, obs, x_no_action, z = _format_observation(obs) position, obs, x_no_action, z = _format_observation(obs)

View File

@ -141,7 +141,7 @@ def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
if all_empty: if all_empty:
time.sleep(0.01) time.sleep(0.01)
def act_queue(i, infer_queue, batch_queues, actor_model, flags): def act_queue(i, infer_queue, batch_queues, flags):
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
try: try:
T = flags.unroll_length T = flags.unroll_length
@ -164,7 +164,7 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34} position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}
position, obs, env_output = env.initial(actor_model, device, flags=flags) position, obs, env_output = env.initial(flags=flags)
while True: while True:
while True: while True:
if len(obs['legal_actions']) > 1: if len(obs['legal_actions']) > 1:
@ -189,7 +189,7 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
x_batch = env_output['obs_x_no_action'].float() x_batch = env_output['obs_x_no_action'].float()
obs_x_batch_buf[position].append(x_batch) obs_x_batch_buf[position].append(x_batch)
type_buf[position].append(position_index[position]) type_buf[position].append(position_index[position])
position, obs, env_output = env.step(action, actor_model, device, flags=flags) position, obs, env_output = env.step(action, flags=flags)
size[position] += 1 size[position] += 1
if env_output['done']: if env_output['done']:
for p in positions: for p in positions:
@ -243,12 +243,12 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
print() print()
raise e raise e
def act(i, infer_queues, batch_queues, actor_model, flags): def act(i, infer_queues, batch_queues, flags):
threads = [] threads = []
for x in range(len(infer_queues)): for x in range(len(infer_queues)):
thread = threading.Thread( thread = threading.Thread(
target=act_queue, name='act_queue-%d-%d' % (i, x), target=act_queue, name='act_queue-%d-%d' % (i, x),
args=(x, infer_queues[x], batch_queues, actor_model, flags)) args=(x, infer_queues[x], batch_queues, flags))
thread.setDaemon(True) thread.setDaemon(True)
thread.start() thread.start()
threads.append(thread) threads.append(thread)

63
douzero/env/env.py vendored
View File

@ -91,7 +91,7 @@ class Env:
self.total_round = 0 self.total_round = 0
self.infoset = None self.infoset = None
def reset(self, model, device, flags=None): def reset(self, flags=None):
""" """
Every time reset is called, the environment Every time reset is called, the environment
will be re-initialized with a new deck of cards. will be re-initialized with a new deck of cards.
@ -100,46 +100,31 @@ class Env:
self._env.reset() self._env.reset()
# Randomly shuffle the deck # Randomly shuffle the deck
if model is None: self.total_round += 1
_deck = deck.copy() _deck = deck.copy()
np.random.shuffle(_deck) np.random.shuffle(_deck)
card_play_data = {'landlord': _deck[:33], card_play_data = {'landlord': _deck[:33],
'landlord_up': _deck[33:58], 'landlord_up': _deck[33:58],
'landlord_front': _deck[58:83], 'landlord_front': _deck[58:83],
'landlord_down': _deck[83:108], 'landlord_down': _deck[83:108],
# 'three_landlord_cards': _deck[17:20], }
} for key in card_play_data:
for key in card_play_data: card_play_data[key].sort()
card_play_data[key].sort() player_ids = {
self._env.card_play_init(card_play_data) 'landlord': 0,
self.infoset = self._game_infoset 'landlord_down': 1,
return get_obs(self.infoset, self.use_general, self.lite_model) 'landlord_front': 2,
else: 'landlord_up': 3,
self.total_round += 1 }
_deck = deck.copy()
np.random.shuffle(_deck)
card_play_data = {'landlord': _deck[:33],
'landlord_up': _deck[33:58],
'landlord_front': _deck[58:83],
'landlord_down': _deck[83:108],
}
for key in card_play_data:
card_play_data[key].sort()
player_ids = {
'landlord': 0,
'landlord_down': 1,
'landlord_front': 2,
'landlord_up': 3,
}
# Initialize the cards # Initialize the cards
self._env.card_play_init(card_play_data) self._env.card_play_init(card_play_data)
for pos in ["landlord", "landlord_up", "landlord_front", "landlord_down"]: for pos in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
pid = player_ids[pos] pid = player_ids[pos]
self._env.info_sets[pos].player_id = pid self._env.info_sets[pos].player_id = pid
self.infoset = self._game_infoset self.infoset = self._game_infoset
return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model) return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
def step(self, action): def step(self, action):
""" """