diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index eddbd26..cf8ed83 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -187,7 +187,7 @@ def train(flags): for i in range(num_actors): actor = mp.Process( target=act, - args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, actor_model, flags)) + args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags)) actor.daemon = True actor.start() actor_processes.append({ @@ -306,14 +306,24 @@ def train(flags): position_fps['landlord_front'], position_fps['landlord_down'], pprint.pformat(stats)) - # for proc in actor_processes: - # if not proc['actor'].is_alive(): - # actor = mp.Process( - # target=act, - # args=(proc['i'], proc['device'], batch_queues, models[device], flags, onnx_frame)) - # actor.daemon = True - # actor.start() - # proc['actor'] = actor + for proc in actor_processes: + if not proc['actor'].is_alive(): + i = proc['i'] + actor = mp.Process( + target=act, + args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags)) + actor.daemon = True + actor.start() + proc['actor'] = actor + + for proc in infer_processes: + if not proc['infer'].is_alive(): + infer = mp.Process( + target=infer_logic, + args=(proc['i'], proc['device'], infer_queues, actor_model, flags, onnx_frame)) + infer.daemon = True + infer.start() + proc['infer'] = infer except KeyboardInterrupt: flags.enable_upload = False diff --git a/douzero/dmc/env_utils.py b/douzero/dmc/env_utils.py index 96fbf0c..ce1de20 100644 --- a/douzero/dmc/env_utils.py +++ b/douzero/dmc/env_utils.py @@ -30,8 +30,8 @@ class Environment: self.device = device self.episode_return = None - def initial(self, model, device, flags=None): - obs = self.env.reset(model, device, flags=flags) + def initial(self, flags=None): + obs = self.env.reset(flags=flags) initial_position, initial_obs, x_no_action, z = _format_observation(obs) self.episode_return = torch.zeros(1, 1) initial_done = torch.ones(1, 1, dtype=torch.bool) @@ -42,13 +42,13 @@ class Environment: 
obs_z=z, ) - def step(self, action, model, device, flags=None): + def step(self, action, flags=None): obs, reward, done, _ = self.env.step(action) self.episode_return = reward episode_return = self.episode_return if done: - obs = self.env.reset(model, device, flags=flags) + obs = self.env.reset(flags=flags) self.episode_return = torch.zeros(1, 1) position, obs, x_no_action, z = _format_observation(obs) diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index 5ce4b6c..bb8b53f 100644 --- a/douzero/dmc/utils.py +++ b/douzero/dmc/utils.py @@ -141,7 +141,7 @@ def infer_logic(i, device, infer_queues, model, flags, onnx_frame): if all_empty: time.sleep(0.01) -def act_queue(i, infer_queue, batch_queues, actor_model, flags): +def act_queue(i, infer_queue, batch_queues, flags): positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] try: T = flags.unroll_length @@ -164,7 +164,7 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags): position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34} - position, obs, env_output = env.initial(actor_model, device, flags=flags) + position, obs, env_output = env.initial(flags=flags) while True: while True: if len(obs['legal_actions']) > 1: @@ -189,7 +189,7 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags): x_batch = env_output['obs_x_no_action'].float() obs_x_batch_buf[position].append(x_batch) type_buf[position].append(position_index[position]) - position, obs, env_output = env.step(action, actor_model, device, flags=flags) + position, obs, env_output = env.step(action, flags=flags) size[position] += 1 if env_output['done']: for p in positions: @@ -243,12 +243,12 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags): print() raise e -def act(i, infer_queues, batch_queues, actor_model, flags): +def act(i, infer_queues, batch_queues, flags): threads = [] for x in range(len(infer_queues)): thread = threading.Thread( target=act_queue, 
name='act_queue-%d-%d' % (i, x), - args=(x, infer_queues[x], batch_queues, actor_model, flags)) + args=(x, infer_queues[x], batch_queues, flags)) thread.setDaemon(True) thread.start() threads.append(thread) diff --git a/douzero/env/env.py b/douzero/env/env.py index 7a50700..d4345b4 100644 --- a/douzero/env/env.py +++ b/douzero/env/env.py @@ -91,7 +91,7 @@ class Env: self.total_round = 0 self.infoset = None - def reset(self, model, device, flags=None): + def reset(self, flags=None): """ Every time reset is called, the environment will be re-initialized with a new deck of cards. @@ -100,46 +100,31 @@ class Env: self._env.reset() # Randomly shuffle the deck - if model is None: - _deck = deck.copy() - np.random.shuffle(_deck) - card_play_data = {'landlord': _deck[:33], - 'landlord_up': _deck[33:58], - 'landlord_front': _deck[58:83], - 'landlord_down': _deck[83:108], - # 'three_landlord_cards': _deck[17:20], - } - for key in card_play_data: - card_play_data[key].sort() - self._env.card_play_init(card_play_data) - self.infoset = self._game_infoset - return get_obs(self.infoset, self.use_general, self.lite_model) - else: - self.total_round += 1 - _deck = deck.copy() - np.random.shuffle(_deck) - card_play_data = {'landlord': _deck[:33], - 'landlord_up': _deck[33:58], - 'landlord_front': _deck[58:83], - 'landlord_down': _deck[83:108], - } - for key in card_play_data: - card_play_data[key].sort() - player_ids = { - 'landlord': 0, - 'landlord_down': 1, - 'landlord_front': 2, - 'landlord_up': 3, - } + self.total_round += 1 + _deck = deck.copy() + np.random.shuffle(_deck) + card_play_data = {'landlord': _deck[:33], + 'landlord_up': _deck[33:58], + 'landlord_front': _deck[58:83], + 'landlord_down': _deck[83:108], + } + for key in card_play_data: + card_play_data[key].sort() + player_ids = { + 'landlord': 0, + 'landlord_down': 1, + 'landlord_front': 2, + 'landlord_up': 3, + } - # Initialize the cards - self._env.card_play_init(card_play_data) - for pos in ["landlord", 
"landlord_up", "landlord_front", "landlord_down"]: - pid = player_ids[pos] - self._env.info_sets[pos].player_id = pid - self.infoset = self._game_infoset + # Initialize the cards + self._env.card_play_init(card_play_data) + for pos in ["landlord", "landlord_up", "landlord_front", "landlord_down"]: + pid = player_ids[pos] + self._env.info_sets[pos].player_id = pid + self.infoset = self._game_infoset - return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model) + return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model) def step(self, action): """