# Source listing: 203 lines, 7.9 KiB, Python.
import argparse
|
|
import pickle
|
|
import numpy as np
|
|
import torch
|
|
import random
|
|
import douzero
|
|
|
|
from douzero.dmc.models import Model
|
|
|
|
# Full card pool (108 cards in total): eight copies of every value from 3
# through 14, eight copies of 17, and two copies each of 20 and 30. The list
# is built in non-decreasing order; callers copy and shuffle it per deal.
deck = [value for value in range(3, 15) for _ in range(8)]
deck += [17] * 8
deck += [20, 20, 30, 30]
|
|
|
|
def get_parser():
    """Build the command-line parser for the evaluation-data generator.

    Registered options (flag, type, default):
        --output       str    'eval_data_500'  (output file stem)
        --path         str    'baselines/resnet_landlord_23358800.ckpt'
        --num_games    int    200
        --exp_epsilon  float  0.01  (not read here; the parsed namespace is
                                     handed on to the model — presumably an
                                     exploration rate, confirm in douzero)

    Returns:
        argparse.ArgumentParser: a fresh parser with the options above.
    """
    option_table = (
        ('--output', str, 'eval_data_500'),
        ('--path', str, 'baselines/resnet_landlord_23358800.ckpt'),
        ('--num_games', int, 200),
        ('--exp_epsilon', float, 0.01),
    )
    cli = argparse.ArgumentParser(description='DouZero: random data generator')
    for flag_name, value_type, fallback in option_table:
        cli.add_argument(flag_name, default=fallback, type=value_type)
    return cli
|
|
|
|
|
|
def _deal_and_bid(decide):
    """Deal a fresh set of hands and run the bidding, redealing until someone bids.

    Args:
        decide: callable ``decide(seat, bid_info, hand)`` returning the bidding
            model's action for ``seat``; a value equal to 1 means "bid".

    Returns:
        ``(hands, landlord_cards, bid_info, landlord)``: ``hands`` is a list of
        four sorted 25-card lists indexed by seat, ``landlord_cards`` is the
        sorted list of the 8 leftover cards, ``bid_info`` is the (5, 4) bid
        history indexed by ``[turn, seat]``, and ``landlord`` is the winning seat.
    """
    while True:
        # NOTE(review): this redeals without limit if the model never bids.
        shuffled = deck.copy()
        np.random.shuffle(shuffled)
        hands = [sorted(shuffled[start:start + 25]) for start in range(0, 100, 25)]
        landlord_cards = sorted(shuffled[100:108])

        # Rows 0-3 hold the first bidding round and row 4 an optional
        # tie-break. -1 marks a turn not (yet) taken; a taken row is 1 for the
        # seat that bid on that turn and 0 elsewhere (all 0 on a pass).
        bid_info = np.full((5, 4), -1)

        seat = random.randint(0, 3)
        first_bidder = -1
        landlord = -1
        for turn in range(4):
            # The model must observe bid_info *before* this turn's row is written.
            action = decide(seat, bid_info, hands[seat])
            bid_info[turn] = 0
            if action == 1:
                bid_info[turn, seat] = 1
                landlord = seat
                if first_bidder == -1:
                    first_bidder = seat
            seat = (seat + 1) % 4

        bids = np.count_nonzero(bid_info == 1)
        if bids == 0:
            continue  # nobody bid: throw the deal away and redeal
        if bids > 1:
            # Contested: the first bidder decides once more (row 4). On a pass,
            # row 4 keeps its -1 markers and the last bidder remains landlord.
            seat = first_bidder
            if decide(seat, bid_info, hands[seat]) == 1:
                bid_info[4] = 0
                bid_info[4, seat] = 1
                landlord = seat
        return hands, landlord_cards, bid_info, landlord


def _build_record(hands, landlord_cards, bid_info, landlord):
    """Turn seat-indexed bidding results into one position-keyed game record.

    The landlord's hand absorbs ``landlord_cards`` (mutating ``hands``) and is
    re-sorted. Each position's bid view reorders the ``bid_info`` columns to
    (up, self, down, front) relative to that position's seat.
    """
    hands[landlord].extend(landlord_cards)
    hands[landlord].sort()
    play, bid = {}, {}
    # Seat offset of each position relative to the landlord; this order is
    # also the key order of both output dicts.
    for position, offset in (('landlord', 0), ('landlord_up', -1),
                             ('landlord_front', 2), ('landlord_down', 1)):
        seat = (landlord + offset) % 4
        play[position] = hands[seat]
        bid[position] = bid_info[:, [(seat - 1) % 4, seat, (seat + 1) % 4, (seat + 2) % 4]]
    return {"play": play, "bid": bid}


def generate_with_bid(num_games, bid_model_path, flags=None):
    """Deal ``num_games`` four-player games and let a bidding model pick the landlord.

    Each game shuffles the module-level ``deck``, deals 25 cards to each of
    four seats, and sets the remaining 8 aside for the landlord. Seats then bid
    in turn using the "bidding" network loaded from ``bid_model_path``. A deal
    nobody bids on is discarded and redealt.

    Args:
        num_games: number of game records to produce (<= 0 yields ``[]``
            without loading the model).
        bid_model_path: checkpoint passed to ``torch.load`` and loaded into the
            bidding model's state dict. ``torch.load`` unpickles, so only pass
            trusted checkpoints.
        flags: forwarded as ``flags=`` to ``Model.forward``. When omitted, the
            module-level ``flags`` parsed by the ``__main__`` block is used
            (the original behaviour); if no such global exists, ``None`` is
            forwarded.

    Returns:
        A list of ``{"play": {...}, "bid": {...}}`` dicts. "play" maps each
        position name to its sorted hand (the landlord's includes the 8 extra
        cards). "bid" maps each position to the bid history with columns
        reordered as (up, self, down, front) relative to that seat.
    """
    if num_games <= 0:
        return []
    if flags is None:
        # Keep script behaviour (which read the global parsed in __main__)
        # while letting importers pass flags explicitly.
        flags = globals().get("flags")

    # Building the model and loading the checkpoint is loop-invariant: do it
    # once instead of once per generated game.
    device = torch.device("cpu")
    model = Model(device='cpu')
    bid_model = model.get_model("bidding")
    bid_model.load_state_dict(torch.load(bid_model_path, map_location=device))
    bid_model.eval()

    def decide(seat, bid_info, hand):
        """Return the bidding model's ``"action"`` output for ``seat``."""
        obs = douzero.env.env._get_obs_for_bid(seat, bid_info, hand)
        with torch.no_grad():
            out = model.forward("bidding",
                                torch.tensor(obs["z_batch"], device=device),
                                torch.tensor(obs["x_batch"], device=device),
                                flags=flags)
        return out["action"]

    return [_build_record(*_deal_and_bid(decide)) for _ in range(num_games)]
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Parsed at module level on purpose: generate_with_bid reads this
    # module-level ``flags`` when forwarding options to the model.
    flags = get_parser().parse_args()
    output_pickle = flags.output + '.pkl'

    print("output_pickle:", output_pickle)
    print("generating data...")
    data = generate_with_bid(flags.num_games, flags.path)
    # Report a summary instead of dumping every generated game to stdout.
    print("generated games:", len(data))

    print("saving pickle file...")
    with open(output_pickle, 'wb') as g:
        pickle.dump(data, g)
|
|
|
|
|
|
|
|
|