移除根据胜率叫地主逻辑(4人场景下,胜率计算未适配)
This commit is contained in:
parent
a755ffe719
commit
c239085c24
|
@ -129,11 +129,11 @@ class Env:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
action = model.forward("bidding", torch.tensor(bidding_obs["z_batch"], device=device),
|
action = model.forward("bidding", torch.tensor(bidding_obs["z_batch"], device=device),
|
||||||
torch.tensor(bidding_obs["x_batch"], device=device), flags=flags)
|
torch.tensor(bidding_obs["x_batch"], device=device), flags=flags)
|
||||||
if bid_limit <= 0:
|
# if bid_limit <= 0:
|
||||||
wr = BidModel.predict_env(card_play_data[bidding_player])
|
# wr = BidModel.predict_env(card_play_data[bidding_player])
|
||||||
if wr >= 0.7:
|
# if wr >= 0.7:
|
||||||
action = {"action": 1} # debug
|
# action = {"action": 1} # debug
|
||||||
bid_limit += 1
|
# bid_limit += 1
|
||||||
|
|
||||||
bid_obs_buffer.append({
|
bid_obs_buffer.append({
|
||||||
"x_batch": bidding_obs["x_batch"][action["action"]],
|
"x_batch": bidding_obs["x_batch"][action["action"]],
|
||||||
|
|
18
evaluate.py
18
evaluate.py
|
@ -5,6 +5,12 @@ from douzero.evaluation.simulation import evaluate
|
||||||
|
|
||||||
|
|
||||||
def make_evaluate(args, t, frame, adp_frame, folder_a = 'baselines', folder_b = 'baselines'):
|
def make_evaluate(args, t, frame, adp_frame, folder_a = 'baselines', folder_b = 'baselines'):
|
||||||
|
if t == 0:
|
||||||
|
args.landlord = 'random'
|
||||||
|
args.landlord_up = 'random'
|
||||||
|
args.landlord_front = 'random'
|
||||||
|
args.landlord_down = 'random'
|
||||||
|
print('random vs random')
|
||||||
if t == 1:
|
if t == 1:
|
||||||
args.landlord = '%s/resnet_landlord_%i.ckpt' % (folder_a, frame)
|
args.landlord = '%s/resnet_landlord_%i.ckpt' % (folder_a, frame)
|
||||||
args.landlord_up = 'random'
|
args.landlord_up = 'random'
|
||||||
|
@ -96,8 +102,13 @@ if __name__ == '__main__':
|
||||||
# ]
|
# ]
|
||||||
|
|
||||||
eval_list = [
|
eval_list = [
|
||||||
[4968800, 8697600, 'baselines', 'baselines2'],
|
# [4968800, 8697600, 'baselines', 'baselines2'],
|
||||||
[4968800, 4968800, 'baselines', 'baselines'],
|
# [4968800, 4968800, 'baselines', 'baselines'],
|
||||||
|
# [14102400, 4968800, 'baselines', 'baselines'],
|
||||||
|
# [14102400, 13252000, 'baselines', 'baselines2'],
|
||||||
|
# [14102400, 15096800, 'baselines', 'baselines2'],
|
||||||
|
[14102400, 14102400, 'baselines', 'baselines'],
|
||||||
|
# [14102400, None, 'baselines', 'baselines'],
|
||||||
]
|
]
|
||||||
|
|
||||||
for vs in reversed(eval_list):
|
for vs in reversed(eval_list):
|
||||||
|
@ -106,6 +117,9 @@ if __name__ == '__main__':
|
||||||
folder_a = vs[2]
|
folder_a = vs[2]
|
||||||
folder_b = vs[3]
|
folder_b = vs[3]
|
||||||
if adp_frame is None:
|
if adp_frame is None:
|
||||||
|
if frame is None:
|
||||||
|
make_evaluate(args, 0, None, None)
|
||||||
|
else:
|
||||||
make_evaluate(args, 1, frame, None)
|
make_evaluate(args, 1, frame, None)
|
||||||
make_evaluate(args, 2, frame, None)
|
make_evaluate(args, 2, frame, None)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,202 @@
|
||||||
|
import argparse
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import random
|
||||||
|
import douzero
|
||||||
|
|
||||||
|
from douzero.dmc.models import Model
|
||||||
|
|
||||||
|
# Full 108-card deck used by this generator: eight copies of each value
# 3..14, eight copies of 17, and the four values 20, 20, 30, 30.
# Card order matches the original construction (ascending, grouped by value).
deck = [rank for rank in range(3, 15) for _ in range(8)]
deck.extend([17] * 8)
deck.extend([20, 20, 30, 30])
|
||||||
|
|
||||||
|
def get_parser():
    """Build the command-line parser for the evaluation-data generator.

    Returns:
        argparse.ArgumentParser: parser exposing ``--output``, ``--path``,
        ``--num_games`` and ``--exp_epsilon``.
    """
    parser = argparse.ArgumentParser(description='DouZero: random data generator')
    parser.add_argument('--output', default='eval_data', type=str,
                        help="Output file name without the '.pkl' suffix")
    parser.add_argument('--path', default='baselines/resnet_bidding_15419200.ckpt', type=str,
                        help='Path to the bidding model checkpoint')
    parser.add_argument('--num_games', default=10000, type=int,
                        help='Number of games to generate')
    # Carried on the parsed flags object that is handed to the bidding model.
    parser.add_argument('--exp_epsilon', default=0.01, type=float,
                        help='Exploration epsilon passed to the bidding model via flags')
    return parser
|
||||||
|
|
||||||
|
|
||||||
|
def generate_with_bid(num_games, bid_model_path, flags=None):
    """Generate dealt games whose landlord is chosen by the bidding model.

    For every game the 108-card ``deck`` is shuffled and dealt as four
    25-card hands plus 8 landlord cards.  Starting from a random seat, each
    of the four players is asked once whether to bid.  If nobody bids the
    cards are re-dealt; if more than one player bid, the first bidder is
    asked once more and takes the landlord seat when they bid again
    (otherwise the last bidder of the round keeps it).

    Args:
        num_games: Number of games to generate.
        bid_model_path: Path of the bidding-model checkpoint to load.
        flags: Optional options object forwarded to ``model.forward``.
            When omitted, the module-level ``flags`` (set when this file is
            run as a script) is used, falling back to the parser defaults so
            the function also works when imported.

    Returns:
        list[dict]: one entry per game with two keys:
            ``"play"`` - dict mapping ``landlord`` / ``landlord_up`` /
            ``landlord_front`` / ``landlord_down`` to that seat's cards
            (the landlord hand includes the 8 landlord cards, sorted);
            ``"bid"`` - dict mapping the same positions to the 5x4 bid
            matrix with its columns reordered relative to that seat.
    """
    if flags is None:
        # Keep the script behaviour (global ``flags`` parsed under __main__)
        # without raising NameError when the module is imported.
        flags = globals().get("flags")
        if flags is None:
            flags = get_parser().parse_args([])

    # Load the checkpoint once per call instead of once per generated game.
    device = torch.device("cpu")
    model = Model(device='cpu')
    bid_model = model.get_model("bidding")
    bid_model.load_state_dict(torch.load(bid_model_path, map_location=device))
    bid_model.eval()

    def _ask_bid(player, bid_info, hand):
        """Return the bidding model's action for ``player`` holding ``hand``."""
        obs = douzero.env.env._get_obs_for_bid(player, bid_info, hand)
        with torch.no_grad():
            action = model.forward("bidding",
                                   torch.tensor(obs["z_batch"], device=device),
                                   torch.tensor(obs["x_batch"], device=device),
                                   flags=flags)
        return action["action"]

    data_list = []
    for _ in range(num_games):
        # Deal and bid until at least one player bids.
        # NOTE(review): this never terminates if the model never bids.
        while True:
            shuffled = deck.copy()
            np.random.shuffle(shuffled)
            hands = [sorted(shuffled[start:start + 25]) for start in (0, 25, 50, 75)]
            landlord_cards = sorted(shuffled[100:108])

            # Rows 0-3: the regular bidding turns; row 4: the first bidder's
            # second turn.  -1 marks a turn that has not been played.
            bid_info = np.full((5, 4), -1)

            bidding_player = random.randint(0, 3)
            first_bid = -1
            last_bid = -1
            for turn in range(4):
                bid_info[turn] = 0
                if _ask_bid(bidding_player, bid_info_before_turn(bid_info, turn),
                            hands[bidding_player]) == 1:
                    last_bid = bidding_player
                    if first_bid == -1:
                        first_bid = bidding_player
                    bid_info[turn][bidding_player] = 1
                bidding_player = (bidding_player + 1) % 4

            bidders = np.count_nonzero(bid_info == 1)
            if bidders == 0:
                continue  # nobody bid: re-deal
            if bidders > 1:
                # Contested: the first bidder decides once more (row 4).
                if _ask_bid(first_bid, bid_info, hands[first_bid]) == 1:
                    last_bid = first_bid
                    bid_info[4] = 0
                    bid_info[4][first_bid] = 1
            break

        # Seats relative to the landlord; insertion order matches the
        # original output dictionaries.
        seat_of = {
            'landlord': last_bid,
            'landlord_up': (last_bid - 1) % 4,
            'landlord_front': (last_bid + 2) % 4,
            'landlord_down': (last_bid + 1) % 4,
        }

        hands[last_bid].extend(landlord_cards)
        hands[last_bid].sort()
        play = {position: hands[seat] for position, seat in seat_of.items()}

        # Per-position view of the bid matrix: columns are reordered to
        # (previous seat, own seat, next seat, opposite seat).
        bids = {
            position: bid_info[:, [(seat - 1) % 4, seat, (seat + 1) % 4, (seat + 2) % 4]]
            for position, seat in seat_of.items()
        }

        data_list.append({"play": play, "bid": bids})
    return data_list


def bid_info_before_turn(bid_info, turn):
    """Return ``bid_info`` as the model saw it before ``turn`` was recorded.

    The original code queried the model while row ``turn`` still held -1 and
    only wrote the row afterwards; this restores that view on a copy so the
    observation passed to the model is unchanged.
    """
    view = bid_info.copy()
    view[turn] = -1
    return view
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # ``flags`` stays a module-level name: generate_with_bid reads it when
    # it forwards options to the bidding model.
    flags = get_parser().parse_args()
    output_pickle = flags.output + '.pkl'

    print("output_pickle:", output_pickle)
    print("generating data...")

    data = []
    data.extend(generate_with_bid(flags.num_games, flags.path))

    # Report a summary instead of dumping every generated game to stdout.
    print("generated %d games" % len(data))
    print("saving pickle file...")
    with open(output_pickle, 'wb') as g:
        pickle.dump(data, g)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue