# Douzero_Resnet/generate_eval_data_with_bid.py
#
# 203 lines
# 7.9 KiB
# Python

import argparse
import pickle
import numpy as np
import torch
import random
import douzero
from douzero.dmc.models import Model
# Master card pile shared by every deal: 108 cards in total.
#   * ranks 3..14, eight copies of each (96 cards)
#   * eight cards of rank 17
#   * two cards each of ranks 20 and 30
# NOTE(review): presumably 17 encodes the "2" and 20/30 the two jokers, as in
# DouZero's usual card encoding, doubled for a two-deck game - confirm.
deck = [rank for rank in range(3, 15) for _ in range(8)]
deck += [17] * 8
deck += [20, 20, 30, 30]
def get_parser():
    """Build the command-line parser for the evaluation-data generator.

    Options (all optional):
        --output       base name of the pickle to write (default 'eval_data_200')
        --path         bidding-model checkpoint to load
        --num_games    number of games to generate (default 200)
        --exp_epsilon  float, default 0.01

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(description='DouZero: random data generator')
    # (flag, default value, converter) for every supported option.
    option_table = (
        ('--output', 'eval_data_200', str),
        ('--path', 'baselines/resnet_bidding_27853600.ckpt', str),
        ('--num_games', 200, int),
        ('--exp_epsilon', 0.01, float),
    )
    for flag, default_value, converter in option_table:
        cli.add_argument(flag, default=default_value, type=converter)
    return cli
_NUM_SEATS = 4     # players at the table
_HAND_SIZE = 25    # cards dealt to each player
_BOTTOM_SIZE = 8   # cards set aside for whoever becomes landlord


def _load_bid_model(bid_model_path, device):
    """Build the model wrapper and load the bidding checkpoint into it."""
    model = Model(device='cpu')
    bid_model = model.get_model("bidding")
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    weights = torch.load(bid_model_path, map_location=device)
    bid_model.load_state_dict(weights)
    bid_model.eval()
    return model


def _deal_hands():
    """Shuffle a copy of the module-level deck; return (sorted hands, sorted bottom cards)."""
    shuffled = deck.copy()
    np.random.shuffle(shuffled)
    hands = [
        sorted(shuffled[seat * _HAND_SIZE:(seat + 1) * _HAND_SIZE])
        for seat in range(_NUM_SEATS)
    ]
    bottom_start = _NUM_SEATS * _HAND_SIZE
    bottom = sorted(shuffled[bottom_start:bottom_start + _BOTTOM_SIZE])
    return hands, bottom


def _wants_landlord(model, device, seat, bid_info, hand, flags):
    """Ask the bidding model whether `seat` bids, given the bid history so far."""
    obs = douzero.env.env._get_obs_for_bid(seat, bid_info, hand)
    with torch.no_grad():
        action = model.forward(
            "bidding",
            torch.tensor(obs["z_batch"], device=device),
            torch.tensor(obs["x_batch"], device=device),
            flags=flags,
        )
    return bool(action["action"] == 1)


def _run_bidding(model, device, hands, flags):
    """Play one bidding phase; return (landlord seat or None if nobody bid, bid_info)."""
    # One row per bidding turn (four regular turns plus one tie-break turn),
    # one column per seat; -1 marks a turn that has not been played.
    bid_info = np.full((_NUM_SEATS + 1, _NUM_SEATS), -1)
    seat = random.randint(0, _NUM_SEATS - 1)
    first_bidder = -1
    last_bidder = -1

    for turn in range(_NUM_SEATS):
        if _wants_landlord(model, device, seat, bid_info, hands[seat], flags):
            last_bidder = seat
            if first_bidder == -1:
                first_bidder = seat
            bid_info[turn] = [int(p == seat) for p in range(_NUM_SEATS)]
        else:
            bid_info[turn] = 0
        seat = (seat + 1) % _NUM_SEATS

    bid_total = np.count_nonzero(bid_info == 1)
    if bid_total == 0:
        return None, bid_info
    if bid_total > 1:
        # Several seats bid: the first bidder gets one more say.  If it bids
        # again it takes the landlord seat; otherwise the most recent bidder
        # keeps it and the tie-break row is left at -1.
        if _wants_landlord(model, device, first_bidder, bid_info,
                           hands[first_bidder], flags):
            last_bidder = first_bidder
            bid_info[_NUM_SEATS] = [int(p == first_bidder) for p in range(_NUM_SEATS)]
    return last_bidder, bid_info


def generate_with_bid(num_games, bid_model_path, flags=None):
    """Generate evaluation deals whose landlord is chosen by a bidding model.

    For every game a 108-card deck is dealt to four seats (25 cards each, 8
    left over), the bidding model is queried for each seat in turn, and the
    deal is repeated until at least one seat bids.  The winning seat receives
    the 8 left-over cards and becomes 'landlord'.

    Args:
        num_games: number of games to generate; values <= 0 yield an empty
            list without loading the model.
        bid_model_path: path of the bidding-model checkpoint.
        flags: object forwarded to ``model.forward(..., flags=flags)``.  When
            omitted, the module-level ``flags`` (set when this file runs as a
            script) is used if it exists, otherwise ``None`` is forwarded.
            NOTE(review): presumably this carries ``exp_epsilon`` for
            exploration inside the model - confirm against ``Model.forward``.

    Returns:
        A list with one dict per game:
            "play": position -> sorted card list, for 'landlord',
                    'landlord_up', 'landlord_front' and 'landlord_down'
                    (seats landlord, landlord-1, landlord+2, landlord+1 mod 4);
            "bid":  position -> 5x4 bid-history array whose columns are
                    reordered to (seat-1, seat, seat+1, seat+2) mod 4 for
                    that position's seat.
    """
    if num_games <= 0:
        return []
    if flags is None:
        # Backward compatibility: earlier versions read a module-level global.
        flags = globals().get("flags")

    device = torch.device("cpu")
    # The model does not change between games, so build and load it once.
    model = _load_bid_model(bid_model_path, device)

    data_list = []
    for _ in range(num_games):
        # Re-deal until somebody is willing to be landlord.
        while True:
            hands, bottom_cards = _deal_hands()
            landlord, bid_info = _run_bidding(model, device, hands, flags)
            if landlord is not None:
                break

        hands[landlord].extend(bottom_cards)
        hands[landlord].sort()

        seat_of = {
            'landlord': landlord,
            'landlord_up': (landlord - 1) % _NUM_SEATS,
            'landlord_front': (landlord + 2) % _NUM_SEATS,
            'landlord_down': (landlord + 1) % _NUM_SEATS,
        }
        play = {position: hands[seat] for position, seat in seat_of.items()}
        bid = {
            position: bid_info[:, [(seat - 1) % _NUM_SEATS,
                                   seat,
                                   (seat + 1) % _NUM_SEATS,
                                   (seat + 2) % _NUM_SEATS]]
            for position, seat in seat_of.items()
        }
        data_list.append({"play": play, "bid": bid})
    return data_list
if __name__ == '__main__':
    # Script entry point: parse the CLI options, generate the games and dump
    # them to "<output>.pkl".
    # `flags` is intentionally a module-level name: generate_with_bid refers
    # to a global called `flags` when it queries the bidding model.
    flags = get_parser().parse_args()
    output_pickle = f"{flags.output}.pkl"
    print("output_pickle:", output_pickle)

    print("generating data...")
    data = list(generate_with_bid(flags.num_games, flags.path))
    print(data)

    print("saving pickle file...")
    with open(output_pickle, 'wb') as pickle_file:
        pickle.dump(data, pickle_file)