移除无效参数

This commit is contained in:
ZaneYork 2022-01-02 22:33:48 +08:00
parent d76cac1335
commit 4571bb3dfc
4 changed files with 52 additions and 57 deletions

View File

@ -187,7 +187,7 @@ def train(flags):
for i in range(num_actors):
actor = mp.Process(
target=act,
args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, actor_model, flags))
args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags))
actor.daemon = True
actor.start()
actor_processes.append({
@ -306,14 +306,24 @@ def train(flags):
position_fps['landlord_front'],
position_fps['landlord_down'],
pprint.pformat(stats))
# for proc in actor_processes:
# if not proc['actor'].is_alive():
# actor = mp.Process(
# target=act,
# args=(proc['i'], proc['device'], batch_queues, models[device], flags, onnx_frame))
# actor.daemon = True
# actor.start()
# proc['actor'] = actor
for proc in actor_processes:
if not proc['actor'].is_alive():
i = proc['i']
actor = mp.Process(
target=act,
args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags))
actor.daemon = True
actor.start()
proc['actor'] = actor
for proc in infer_processes:
if not proc['infer'].is_alive():
infer = mp.Process(
target=infer_logic,
args=(proc['i'], proc['device'], infer_queues, actor_model, flags, onnx_frame))
infer.daemon = True
infer.start()
proc['infer'] = actor
except KeyboardInterrupt:
flags.enable_upload = False

View File

@ -30,8 +30,8 @@ class Environment:
self.device = device
self.episode_return = None
def initial(self, model, device, flags=None):
obs = self.env.reset(model, device, flags=flags)
def initial(self, flags=None):
obs = self.env.reset(flags=flags)
initial_position, initial_obs, x_no_action, z = _format_observation(obs)
self.episode_return = torch.zeros(1, 1)
initial_done = torch.ones(1, 1, dtype=torch.bool)
@ -42,13 +42,13 @@ class Environment:
obs_z=z,
)
def step(self, action, model, device, flags=None):
def step(self, action, flags=None):
obs, reward, done, _ = self.env.step(action)
self.episode_return = reward
episode_return = self.episode_return
if done:
obs = self.env.reset(model, device, flags=flags)
obs = self.env.reset(flags=flags)
self.episode_return = torch.zeros(1, 1)
position, obs, x_no_action, z = _format_observation(obs)

View File

@ -141,7 +141,7 @@ def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
if all_empty:
time.sleep(0.01)
def act_queue(i, infer_queue, batch_queues, actor_model, flags):
def act_queue(i, infer_queue, batch_queues, flags):
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
try:
T = flags.unroll_length
@ -164,7 +164,7 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}
position, obs, env_output = env.initial(actor_model, device, flags=flags)
position, obs, env_output = env.initial(flags=flags)
while True:
while True:
if len(obs['legal_actions']) > 1:
@ -189,7 +189,7 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
x_batch = env_output['obs_x_no_action'].float()
obs_x_batch_buf[position].append(x_batch)
type_buf[position].append(position_index[position])
position, obs, env_output = env.step(action, actor_model, device, flags=flags)
position, obs, env_output = env.step(action, flags=flags)
size[position] += 1
if env_output['done']:
for p in positions:
@ -243,12 +243,12 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
print()
raise e
def act(i, infer_queues, batch_queues, actor_model, flags):
def act(i, infer_queues, batch_queues, flags):
threads = []
for x in range(len(infer_queues)):
thread = threading.Thread(
target=act_queue, name='act_queue-%d-%d' % (i, x),
args=(x, infer_queues[x], batch_queues, actor_model, flags))
args=(x, infer_queues[x], batch_queues, flags))
thread.setDaemon(True)
thread.start()
threads.append(thread)

63
douzero/env/env.py vendored
View File

@ -91,7 +91,7 @@ class Env:
self.total_round = 0
self.infoset = None
def reset(self, model, device, flags=None):
def reset(self, flags=None):
"""
Every time reset is called, the environment
will be re-initialized with a new deck of cards.
@ -100,46 +100,31 @@ class Env:
self._env.reset()
# Randomly shuffle the deck
if model is None:
_deck = deck.copy()
np.random.shuffle(_deck)
card_play_data = {'landlord': _deck[:33],
'landlord_up': _deck[33:58],
'landlord_front': _deck[58:83],
'landlord_down': _deck[83:108],
# 'three_landlord_cards': _deck[17:20],
}
for key in card_play_data:
card_play_data[key].sort()
self._env.card_play_init(card_play_data)
self.infoset = self._game_infoset
return get_obs(self.infoset, self.use_general, self.lite_model)
else:
self.total_round += 1
_deck = deck.copy()
np.random.shuffle(_deck)
card_play_data = {'landlord': _deck[:33],
'landlord_up': _deck[33:58],
'landlord_front': _deck[58:83],
'landlord_down': _deck[83:108],
}
for key in card_play_data:
card_play_data[key].sort()
player_ids = {
'landlord': 0,
'landlord_down': 1,
'landlord_front': 2,
'landlord_up': 3,
}
self.total_round += 1
_deck = deck.copy()
np.random.shuffle(_deck)
card_play_data = {'landlord': _deck[:33],
'landlord_up': _deck[33:58],
'landlord_front': _deck[58:83],
'landlord_down': _deck[83:108],
}
for key in card_play_data:
card_play_data[key].sort()
player_ids = {
'landlord': 0,
'landlord_down': 1,
'landlord_front': 2,
'landlord_up': 3,
}
# Initialize the cards
self._env.card_play_init(card_play_data)
for pos in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
pid = player_ids[pos]
self._env.info_sets[pos].player_id = pid
self.infoset = self._game_infoset
# Initialize the cards
self._env.card_play_init(card_play_data)
for pos in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
pid = player_ids[pos]
self._env.info_sets[pos].player_id = pid
self.infoset = self._game_infoset
return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
def step(self, action):
"""