移除无效参数
This commit is contained in:
parent
d76cac1335
commit
4571bb3dfc
|
@ -187,7 +187,7 @@ def train(flags):
|
||||||
for i in range(num_actors):
|
for i in range(num_actors):
|
||||||
actor = mp.Process(
|
actor = mp.Process(
|
||||||
target=act,
|
target=act,
|
||||||
args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, actor_model, flags))
|
args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags))
|
||||||
actor.daemon = True
|
actor.daemon = True
|
||||||
actor.start()
|
actor.start()
|
||||||
actor_processes.append({
|
actor_processes.append({
|
||||||
|
@ -306,14 +306,24 @@ def train(flags):
|
||||||
position_fps['landlord_front'],
|
position_fps['landlord_front'],
|
||||||
position_fps['landlord_down'],
|
position_fps['landlord_down'],
|
||||||
pprint.pformat(stats))
|
pprint.pformat(stats))
|
||||||
# for proc in actor_processes:
|
for proc in actor_processes:
|
||||||
# if not proc['actor'].is_alive():
|
if not proc['actor'].is_alive():
|
||||||
# actor = mp.Process(
|
i = proc['i']
|
||||||
# target=act,
|
actor = mp.Process(
|
||||||
# args=(proc['i'], proc['device'], batch_queues, models[device], flags, onnx_frame))
|
target=act,
|
||||||
# actor.daemon = True
|
args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags))
|
||||||
# actor.start()
|
actor.daemon = True
|
||||||
# proc['actor'] = actor
|
actor.start()
|
||||||
|
proc['actor'] = actor
|
||||||
|
|
||||||
|
for proc in infer_processes:
|
||||||
|
if not proc['infer'].is_alive():
|
||||||
|
infer = mp.Process(
|
||||||
|
target=infer_logic,
|
||||||
|
args=(proc['i'], proc['device'], infer_queues, actor_model, flags, onnx_frame))
|
||||||
|
infer.daemon = True
|
||||||
|
infer.start()
|
||||||
|
proc['infer'] = infer
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
flags.enable_upload = False
|
flags.enable_upload = False
|
||||||
|
|
|
@ -30,8 +30,8 @@ class Environment:
|
||||||
self.device = device
|
self.device = device
|
||||||
self.episode_return = None
|
self.episode_return = None
|
||||||
|
|
||||||
def initial(self, model, device, flags=None):
|
def initial(self, flags=None):
|
||||||
obs = self.env.reset(model, device, flags=flags)
|
obs = self.env.reset(flags=flags)
|
||||||
initial_position, initial_obs, x_no_action, z = _format_observation(obs)
|
initial_position, initial_obs, x_no_action, z = _format_observation(obs)
|
||||||
self.episode_return = torch.zeros(1, 1)
|
self.episode_return = torch.zeros(1, 1)
|
||||||
initial_done = torch.ones(1, 1, dtype=torch.bool)
|
initial_done = torch.ones(1, 1, dtype=torch.bool)
|
||||||
|
@ -42,13 +42,13 @@ class Environment:
|
||||||
obs_z=z,
|
obs_z=z,
|
||||||
)
|
)
|
||||||
|
|
||||||
def step(self, action, model, device, flags=None):
|
def step(self, action, flags=None):
|
||||||
obs, reward, done, _ = self.env.step(action)
|
obs, reward, done, _ = self.env.step(action)
|
||||||
|
|
||||||
self.episode_return = reward
|
self.episode_return = reward
|
||||||
episode_return = self.episode_return
|
episode_return = self.episode_return
|
||||||
if done:
|
if done:
|
||||||
obs = self.env.reset(model, device, flags=flags)
|
obs = self.env.reset(flags=flags)
|
||||||
self.episode_return = torch.zeros(1, 1)
|
self.episode_return = torch.zeros(1, 1)
|
||||||
|
|
||||||
position, obs, x_no_action, z = _format_observation(obs)
|
position, obs, x_no_action, z = _format_observation(obs)
|
||||||
|
|
|
@ -141,7 +141,7 @@ def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
|
||||||
if all_empty:
|
if all_empty:
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
|
|
||||||
def act_queue(i, infer_queue, batch_queues, actor_model, flags):
|
def act_queue(i, infer_queue, batch_queues, flags):
|
||||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
try:
|
try:
|
||||||
T = flags.unroll_length
|
T = flags.unroll_length
|
||||||
|
@ -164,7 +164,7 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
|
||||||
|
|
||||||
position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}
|
position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}
|
||||||
|
|
||||||
position, obs, env_output = env.initial(actor_model, device, flags=flags)
|
position, obs, env_output = env.initial(flags=flags)
|
||||||
while True:
|
while True:
|
||||||
while True:
|
while True:
|
||||||
if len(obs['legal_actions']) > 1:
|
if len(obs['legal_actions']) > 1:
|
||||||
|
@ -189,7 +189,7 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
|
||||||
x_batch = env_output['obs_x_no_action'].float()
|
x_batch = env_output['obs_x_no_action'].float()
|
||||||
obs_x_batch_buf[position].append(x_batch)
|
obs_x_batch_buf[position].append(x_batch)
|
||||||
type_buf[position].append(position_index[position])
|
type_buf[position].append(position_index[position])
|
||||||
position, obs, env_output = env.step(action, actor_model, device, flags=flags)
|
position, obs, env_output = env.step(action, flags=flags)
|
||||||
size[position] += 1
|
size[position] += 1
|
||||||
if env_output['done']:
|
if env_output['done']:
|
||||||
for p in positions:
|
for p in positions:
|
||||||
|
@ -243,12 +243,12 @@ def act_queue(i, infer_queue, batch_queues, actor_model, flags):
|
||||||
print()
|
print()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def act(i, infer_queues, batch_queues, actor_model, flags):
|
def act(i, infer_queues, batch_queues, flags):
|
||||||
threads = []
|
threads = []
|
||||||
for x in range(len(infer_queues)):
|
for x in range(len(infer_queues)):
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=act_queue, name='act_queue-%d-%d' % (i, x),
|
target=act_queue, name='act_queue-%d-%d' % (i, x),
|
||||||
args=(x, infer_queues[x], batch_queues, actor_model, flags))
|
args=(x, infer_queues[x], batch_queues, flags))
|
||||||
thread.setDaemon(True)
|
thread.setDaemon(True)
|
||||||
thread.start()
|
thread.start()
|
||||||
threads.append(thread)
|
threads.append(thread)
|
||||||
|
|
|
@ -91,7 +91,7 @@ class Env:
|
||||||
self.total_round = 0
|
self.total_round = 0
|
||||||
self.infoset = None
|
self.infoset = None
|
||||||
|
|
||||||
def reset(self, model, device, flags=None):
|
def reset(self, flags=None):
|
||||||
"""
|
"""
|
||||||
Every time reset is called, the environment
|
Every time reset is called, the environment
|
||||||
will be re-initialized with a new deck of cards.
|
will be re-initialized with a new deck of cards.
|
||||||
|
@ -100,46 +100,31 @@ class Env:
|
||||||
self._env.reset()
|
self._env.reset()
|
||||||
|
|
||||||
# Randomly shuffle the deck
|
# Randomly shuffle the deck
|
||||||
if model is None:
|
self.total_round += 1
|
||||||
_deck = deck.copy()
|
_deck = deck.copy()
|
||||||
np.random.shuffle(_deck)
|
np.random.shuffle(_deck)
|
||||||
card_play_data = {'landlord': _deck[:33],
|
card_play_data = {'landlord': _deck[:33],
|
||||||
'landlord_up': _deck[33:58],
|
'landlord_up': _deck[33:58],
|
||||||
'landlord_front': _deck[58:83],
|
'landlord_front': _deck[58:83],
|
||||||
'landlord_down': _deck[83:108],
|
'landlord_down': _deck[83:108],
|
||||||
# 'three_landlord_cards': _deck[17:20],
|
}
|
||||||
}
|
for key in card_play_data:
|
||||||
for key in card_play_data:
|
card_play_data[key].sort()
|
||||||
card_play_data[key].sort()
|
player_ids = {
|
||||||
self._env.card_play_init(card_play_data)
|
'landlord': 0,
|
||||||
self.infoset = self._game_infoset
|
'landlord_down': 1,
|
||||||
return get_obs(self.infoset, self.use_general, self.lite_model)
|
'landlord_front': 2,
|
||||||
else:
|
'landlord_up': 3,
|
||||||
self.total_round += 1
|
}
|
||||||
_deck = deck.copy()
|
|
||||||
np.random.shuffle(_deck)
|
|
||||||
card_play_data = {'landlord': _deck[:33],
|
|
||||||
'landlord_up': _deck[33:58],
|
|
||||||
'landlord_front': _deck[58:83],
|
|
||||||
'landlord_down': _deck[83:108],
|
|
||||||
}
|
|
||||||
for key in card_play_data:
|
|
||||||
card_play_data[key].sort()
|
|
||||||
player_ids = {
|
|
||||||
'landlord': 0,
|
|
||||||
'landlord_down': 1,
|
|
||||||
'landlord_front': 2,
|
|
||||||
'landlord_up': 3,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initialize the cards
|
# Initialize the cards
|
||||||
self._env.card_play_init(card_play_data)
|
self._env.card_play_init(card_play_data)
|
||||||
for pos in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
|
for pos in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
|
||||||
pid = player_ids[pos]
|
pid = player_ids[pos]
|
||||||
self._env.info_sets[pos].player_id = pid
|
self._env.info_sets[pos].player_id = pid
|
||||||
self.infoset = self._game_infoset
|
self.infoset = self._game_infoset
|
||||||
|
|
||||||
return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
|
return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue