移除无效参数
This commit is contained in:
parent
d76cac1335
commit
4571bb3dfc
|
@ -187,7 +187,7 @@ def train(flags):
|
|||
for i in range(num_actors):
|
||||
actor = mp.Process(
|
||||
target=act,
|
||||
args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, actor_model, flags))
|
||||
args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags))
|
||||
actor.daemon = True
|
||||
actor.start()
|
||||
actor_processes.append({
|
||||
|
@ -306,14 +306,24 @@ def train(flags):
|
|||
position_fps['landlord_front'],
|
||||
position_fps['landlord_down'],
|
||||
pprint.pformat(stats))
|
||||
# for proc in actor_processes:
|
||||
# if not proc['actor'].is_alive():
|
||||
# actor = mp.Process(
|
||||
# target=act,
|
||||
# args=(proc['i'], proc['device'], batch_queues, models[device], flags, onnx_frame))
|
||||
# actor.daemon = True
|
||||
# actor.start()
|
||||
# proc['actor'] = actor
|
||||
for proc in actor_processes:
|
||||
if not proc['actor'].is_alive():
|
||||
i = proc['i']
|
||||
actor = mp.Process(
|
||||
target=act,
|
||||
args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags))
|
||||
actor.daemon = True
|
||||
actor.start()
|
||||
proc['actor'] = actor
|
||||
|
||||
for proc in infer_processes:
|
||||
if not proc['infer'].is_alive():
|
||||
infer = mp.Process(
|
||||
target=infer_logic,
|
||||
args=(proc['i'], proc['device'], infer_queues, actor_model, flags, onnx_frame))
|
||||
infer.daemon = True
|
||||
infer.start()
|
||||
proc['infer'] = actor
|
||||
|
||||
except KeyboardInterrupt:
|
||||
flags.enable_upload = False
|
||||
|
|
|
@ -30,8 +30,8 @@ class Environment:
|
|||
self.device = device
|
||||
self.episode_return = None
|
||||
|
||||
def initial(self, model, device, flags=None):
|
||||
obs = self.env.reset(model, device, flags=flags)
|
||||
def initial(self, flags=None):
|
||||
obs = self.env.reset(flags=flags)
|
||||
initial_position, initial_obs, x_no_action, z = _format_observation(obs)
|
||||
self.episode_return = torch.zeros(1, 1)
|
||||
initial_done = torch.ones(1, 1, dtype=torch.bool)
|
||||
|
@ -42,13 +42,13 @@ class Environment:
|
|||
obs_z=z,
|
||||
)
|
||||
|
||||
def step(self, action, model, device, flags=None):
|
||||
def step(self, action, flags=None):
|
||||
obs, reward, done, _ = self.env.step(action)
|
||||
|
||||
self.episode_return = reward
|
||||
episode_return = self.episode_return
|
||||
if done:
|
||||
obs = self.env.reset(model, device, flags=flags)
|
||||
obs = self.env.reset(flags=flags)
|
||||
self.episode_return = torch.zeros(1, 1)
|
||||
|
||||
position, obs, x_no_action, z = _format_observation(obs)
|
||||
|
|
|
@ -141,7 +141,7 @@ def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
|
|||
if all_empty:
|
||||
time.sleep(0.01)
|
||||
|
||||
def act_queue(i, infer_queue, batch_queues, actor_model, flags):
|
||||
def act_queue(i, infer_queue, batch_queues, flags):
|
||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||
try:
|
||||
T = flags.unroll_length
|
||||
|
@ -164,7 +164,7 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
|
|||
|
||||
position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}
|
||||
|
||||
position, obs, env_output = env.initial(actor_model, device, flags=flags)
|
||||
position, obs, env_output = env.initial(flags=flags)
|
||||
while True:
|
||||
while True:
|
||||
if len(obs['legal_actions']) > 1:
|
||||
|
@ -189,7 +189,7 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
|
|||
x_batch = env_output['obs_x_no_action'].float()
|
||||
obs_x_batch_buf[position].append(x_batch)
|
||||
type_buf[position].append(position_index[position])
|
||||
position, obs, env_output = env.step(action, actor_model, device, flags=flags)
|
||||
position, obs, env_output = env.step(action, flags=flags)
|
||||
size[position] += 1
|
||||
if env_output['done']:
|
||||
for p in positions:
|
||||
|
@ -243,12 +243,12 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
|
|||
print()
|
||||
raise e
|
||||
|
||||
def act(i, infer_queues, batch_queues, actor_model, flags):
|
||||
def act(i, infer_queues, batch_queues, flags):
|
||||
threads = []
|
||||
for x in range(len(infer_queues)):
|
||||
thread = threading.Thread(
|
||||
target=act_queue, name='act_queue-%d-%d' % (i, x),
|
||||
args=(x, infer_queues[x], batch_queues, actor_model, flags))
|
||||
args=(x, infer_queues[x], batch_queues, flags))
|
||||
thread.setDaemon(True)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
|
|
@ -91,7 +91,7 @@ class Env:
|
|||
self.total_round = 0
|
||||
self.infoset = None
|
||||
|
||||
def reset(self, model, device, flags=None):
|
||||
def reset(self, flags=None):
|
||||
"""
|
||||
Every time reset is called, the environment
|
||||
will be re-initialized with a new deck of cards.
|
||||
|
@ -100,21 +100,6 @@ class Env:
|
|||
self._env.reset()
|
||||
|
||||
# Randomly shuffle the deck
|
||||
if model is None:
|
||||
_deck = deck.copy()
|
||||
np.random.shuffle(_deck)
|
||||
card_play_data = {'landlord': _deck[:33],
|
||||
'landlord_up': _deck[33:58],
|
||||
'landlord_front': _deck[58:83],
|
||||
'landlord_down': _deck[83:108],
|
||||
# 'three_landlord_cards': _deck[17:20],
|
||||
}
|
||||
for key in card_play_data:
|
||||
card_play_data[key].sort()
|
||||
self._env.card_play_init(card_play_data)
|
||||
self.infoset = self._game_infoset
|
||||
return get_obs(self.infoset, self.use_general, self.lite_model)
|
||||
else:
|
||||
self.total_round += 1
|
||||
_deck = deck.copy()
|
||||
np.random.shuffle(_deck)
|
||||
|
|
Loading…
Reference in New Issue