# File: Douzero_Resnet/douzero/evaluation/simulation.py
# (136 lines, 5.1 KiB, Python)

import multiprocessing as mp
import pickle
import douzero.env.env
from douzero.dmc.models import Model
from douzero.env.game import GameEnv
import torch
import numpy as np
import psutil
def load_card_play_models(card_play_model_path_dict):
    """Instantiate one card-play agent for each of the four seats.

    Args:
        card_play_model_path_dict: mapping from seat name ('landlord',
            'landlord_up', 'landlord_front', 'landlord_down') to an agent
            spec. The spec 'rlcard' selects RLCardAgent, 'random' selects
            RandomAgent, and any other value is handed to DeepAgent as its
            model path (loaded with ``use_onnx=True``).

    Returns:
        dict mapping each seat name to its agent, in the seat order above.
    """
    seats = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')
    agents = {}
    for seat in seats:
        spec = card_play_model_path_dict[seat]
        # Agent modules are imported lazily so that only the backends that
        # are actually requested need to be importable.
        if spec == 'rlcard':
            from .rlcard_agent import RLCardAgent
            agent = RLCardAgent(seat)
        elif spec == 'random':
            from .random_agent import RandomAgent
            agent = RandomAgent()
        else:
            from .deep_agent import DeepAgent
            agent = DeepAgent(seat, spec, use_onnx=True)
        agents[seat] = agent
    return agents
# Mapping from the environment's integer card codes to their display characters.
_ENV_CARD_TO_REAL_CARD = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
                          8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
                          13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}

# Seat order used when printing the initial hands of a deal.
_HAND_PRINT_ORDER = ('landlord', 'landlord_down', 'landlord_front', 'landlord_up')


def _cards_to_str(cards):
    """Render a sequence of environment card codes as a compact string."""
    return "".join([_ENV_CARD_TO_REAL_CARD[c] for c in cards])


def mp_simulate(card_play_data_list, players, q, output, title):
    """Play every deal in *card_play_data_list* and report the totals on *q*.

    Intended as the target of a worker process started by ``evaluate``.

    Args:
        card_play_data_list: iterable of deals; each is passed to
            ``GameEnv.card_play_init``. When printing, the hands are read
            from ``deal['play']`` if that key exists, else from ``deal``
            itself.
        players: seat-name -> agent mapping passed to ``GameEnv``.
        q: queue that receives exactly one tuple
            ``(landlord_wins, farmer_wins, landlord_scores, farmer_scores)``
            taken from ``env.num_wins`` / ``env.num_scores`` after all deals.
        output: when true, print each deal's hands and the action trace.
        title: label printed at the start of each game when *output* is set.
    """
    env = GameEnv(players)
    for idx, card_play_data in enumerate(card_play_data_list):
        env.card_play_init(card_play_data)
        if output:
            print("\nStart ------- " + title)
            # Some data files nest the hands under a 'play' key; others keep
            # them at the top level of the deal dict.
            hands = card_play_data['play'] if 'play' in card_play_data else card_play_data
            for position in _HAND_PRINT_ORDER:
                print(_cards_to_str(hands[position]))
        count = 0
        while not env.game_over:
            action = env.step()
            if output:
                # Four seats: start a new line after every full round of moves.
                end = "\n" if count % 4 == 3 else " "
                if len(action) == 0:
                    print("Pass", end=end)
                else:
                    print(_cards_to_str(action), end=end)
            count += 1
        if output and idx % 10 == 0:
            print("\nindex", idx)
        env.reset()
    q.put((env.num_wins['landlord'],
           env.num_wins['farmer'],
           env.num_scores['landlord'],
           env.num_scores['farmer']
           ))
def data_allocation_per_worker(card_play_data_list, num_workers):
    """Split *card_play_data_list* round-robin into ``num_workers`` lists.

    Item ``i`` is assigned to worker ``i % num_workers``, so the resulting
    list sizes differ by at most one. Workers may receive an empty list when
    there are fewer items than workers.

    Args:
        card_play_data_list: iterable of deals to distribute.
        num_workers: number of worker buckets; must be at least 1.

    Returns:
        A list of ``num_workers`` lists.

    Raises:
        ValueError: if ``num_workers`` is less than 1. Without this check,
            0 gave a ZeroDivisionError and negatives gave an IndexError.
    """
    if num_workers < 1:
        raise ValueError(
            'num_workers must be at least 1, got {}'.format(num_workers))
    buckets = [[] for _ in range(num_workers)]
    for idx, data in enumerate(card_play_data_list):
        buckets[idx % num_workers].append(data)
    return buckets
def _lower_priority(proc):
    """Best-effort lowering of the scheduling priority of psutil process *proc*.

    ``psutil.BELOW_NORMAL_PRIORITY_CLASS`` is only defined on Windows. On
    POSIX, ``Process.nice`` takes a niceness value instead, so fall back to
    10 there. A process that exited between being listed and being reniced
    raises ``psutil.NoSuchProcess``. That case is skipped rather than
    aborting the evaluation.
    """
    level = getattr(psutil, 'BELOW_NORMAL_PRIORITY_CLASS', 10)
    try:
        proc.nice(level)
    except psutil.NoSuchProcess:
        pass


def evaluate(landlord, landlord_up, landlord_front, landlord_down, eval_data,
             num_workers, output, title):
    """Evaluate four seat agents over a pickled set of deals using worker processes.

    Args:
        landlord, landlord_up, landlord_front, landlord_down: agent spec for
            each seat: 'rlcard', 'random', or a model path for DeepAgent
            (see ``load_card_play_models``).
        eval_data: path to a pickle file holding the list of deals.
        num_workers: number of worker processes (must be at least 1).
        output: when true, workers print hands and action traces.
        title: label printed by workers at the start of every game.

    Returns:
        ``(landlord_wp, farmer_wp, landlord_adp, farmer_adp)``.

    Raises:
        ZeroDivisionError: if the workers report no wins at all (e.g. empty
            evaluation data).
    """
    with open(eval_data, 'rb') as f:
        # NOTE(review): pickle can execute arbitrary code when loading;
        # eval_data must come from a trusted source.
        card_play_data_list = pickle.load(f)
    card_play_data_list_each_worker = data_allocation_per_worker(
        card_play_data_list, num_workers)
    del card_play_data_list  # free the full list; workers get their own slices

    card_play_model_path_dict = {
        'landlord': landlord,
        'landlord_up': landlord_up,
        'landlord_front': landlord_front,
        'landlord_down': landlord_down}

    ctx = mp.get_context('spawn')
    q = ctx.SimpleQueue()
    players = load_card_play_models(card_play_model_path_dict)

    processes = []
    for worker_data in card_play_data_list_each_worker:
        p = ctx.Process(
            target=mp_simulate,
            args=(worker_data, players, q, output, title))
        p.start()
        processes.append(p)

    # Keep the machine responsive while the evaluation runs.
    parent = psutil.Process()
    _lower_priority(parent)
    for child in parent.children():
        _lower_priority(child)

    # Drain every worker's result before joining. Joining a process that is
    # still blocked writing to a queue can deadlock.
    num_landlord_wins = 0
    num_farmer_wins = 0
    num_landlord_scores = 0
    num_farmer_scores = 0
    for _ in processes:
        result = q.get()
        num_landlord_wins += result[0]
        num_farmer_wins += result[1]
        num_landlord_scores += result[2]
        num_farmer_scores += result[3]
    for p in processes:
        p.join()

    # Presumably each game has exactly one winning side, so this equals the
    # number of games played. TODO confirm against GameEnv.
    num_total_wins = num_landlord_wins + num_farmer_wins
    print('WP results:')
    landlord_wp = num_landlord_wins / num_total_wins
    farmer_wp = num_farmer_wins / num_total_wins
    print('landlord : Farmers - {} : {}'.format(landlord_wp, farmer_wp))
    print('ADP results:')
    landlord_adp = num_landlord_scores / num_total_wins
    # NOTE(review): the factor 3 presumably rescales the farmer score for the
    # three farmer seats. Confirm against GameEnv's num_scores bookkeeping.
    farmer_adp = 3 * num_farmer_scores / num_total_wins
    print('landlord : Farmers - {} : {}'.format(landlord_adp, farmer_adp))
    return landlord_wp, farmer_wp, landlord_adp, farmer_adp