136 lines
5.1 KiB
Python
136 lines
5.1 KiB
Python
import multiprocessing as mp
|
|
import pickle
|
|
import douzero.env.env
|
|
from douzero.dmc.models import Model
|
|
from douzero.env.game import GameEnv
|
|
import torch
|
|
import numpy as np
|
|
import psutil
|
|
|
|
def load_card_play_models(card_play_model_path_dict):
    """Create one agent per seat from a seat -> agent-spec mapping.

    For each of the four seats the mapped value selects the agent class:
    ``'rlcard'`` builds an ``RLCardAgent`` for that seat, ``'random'`` builds
    a ``RandomAgent``, and any other value is handed to ``DeepAgent`` as its
    second argument with ``use_onnx=True`` (presumably a model file path --
    confirm against ``DeepAgent``).

    Args:
        card_play_model_path_dict: Mapping with the keys ``'landlord'``,
            ``'landlord_up'``, ``'landlord_front'`` and ``'landlord_down'``.

    Returns:
        Dict mapping each seat name to its agent instance.

    Raises:
        KeyError: If a seat is missing from ``card_play_model_path_dict``.
    """
    seats = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')
    agents = {}

    for seat in seats:
        spec = card_play_model_path_dict[seat]

        # Agent modules are imported lazily so that only the backends that
        # are actually requested get loaded.
        if spec == 'rlcard':
            from .rlcard_agent import RLCardAgent
            agent = RLCardAgent(seat)
        elif spec == 'random':
            from .random_agent import RandomAgent
            agent = RandomAgent()
        else:
            from .deep_agent import DeepAgent
            agent = DeepAgent(seat, spec, use_onnx=True)

        agents[seat] = agent

    return agents
|
|
|
|
def mp_simulate(card_play_data_list, players, q, output, title):
    """Play every deal in ``card_play_data_list`` and put the totals on ``q``.

    Args:
        card_play_data_list: Deals to play; each one is passed to
            ``GameEnv.card_play_init``. When ``output`` is truthy, a deal must
            hold the four seats' cards either directly or under a ``'play'``
            key, because they are printed before the game starts.
        players: Passed straight to ``GameEnv``.
        q: Object with a ``put`` method. Receives a single 4-tuple
            ``(num_wins['landlord'], num_wins['farmer'],
            num_scores['landlord'], num_scores['farmer'])`` read from the
            environment after all deals were played.
        output: When truthy, print a banner, the four starting hands, every
            action taken, and the deal index every tenth deal.
        title: Text appended to the ``Start`` banner.
    """
    card_symbols = {
        3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
        8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
        13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D',
    }
    # Order in which the starting hands are printed.
    seat_print_order = ("landlord", "landlord_down", "landlord_front", "landlord_up")

    def render(cards):
        """Return the cards as one string of single-character symbols."""
        return "".join([card_symbols[c] for c in cards])

    env = GameEnv(players)
    for idx, card_play_data in enumerate(card_play_data_list):
        env.card_play_init(card_play_data)

        if output:
            print("\nStart ------- " + title)
            # The hands live either under 'play' or at the top level.
            hands = card_play_data['play'] if 'play' in card_play_data else card_play_data
            for seat in seat_print_order:
                print(render(hands[seat]))

        moves_printed = 0
        while not env.game_over:
            action = env.step()
            if not output:
                continue
            # Four actions per printed line: break the line after every
            # fourth one.
            separator = "\n" if moves_printed % 4 == 3 else " "
            text = "Pass" if len(action) == 0 else render(action)
            print(text, end=separator)
            moves_printed += 1

        if output and idx % 10 == 0:
            print("\nindex", idx)

        env.reset()

    q.put((
        env.num_wins['landlord'],
        env.num_wins['farmer'],
        env.num_scores['landlord'],
        env.num_scores['farmer'],
    ))
|
|
|
|
def data_allocation_per_worker(card_play_data_list, num_workers):
    """Split the deals round-robin into ``num_workers`` lists.

    The deal at position ``i`` goes to the list at index ``i % num_workers``,
    so each list keeps the original relative order of its deals.

    Args:
        card_play_data_list: Iterable of deals.
        num_workers: Number of lists to produce.

    Returns:
        A list of ``num_workers`` lists that together hold every deal once.
    """
    shards = [[] for _ in range(num_workers)]

    for position, deal in enumerate(card_play_data_list):
        worker = position % num_workers
        shards[worker].append(deal)

    return shards
|
|
|
|
def evaluate(landlord, landlord_up, landlord_front, landlord_down, eval_data, num_workers, output, title):
    """Play the pickled evaluation deals in worker processes and report WP/ADP.

    The deals are split round-robin over ``num_workers`` processes, each of
    which runs ``mp_simulate``. The per-worker totals are summed and the
    win percentage (WP) and average score per game (ADP) are printed and
    returned.

    Args:
        landlord: Agent spec for the ``'landlord'`` seat, forwarded to
            ``load_card_play_models``.
        landlord_up: Agent spec for the ``'landlord_up'`` seat.
        landlord_front: Agent spec for the ``'landlord_front'`` seat.
        landlord_down: Agent spec for the ``'landlord_down'`` seat.
        eval_data: Path of the pickle file holding the deals.
        num_workers: Number of worker processes.
        output: Forwarded to ``mp_simulate`` (enables per-game printing).
        title: Forwarded to ``mp_simulate`` (printed in its banner).

    Returns:
        Tuple ``(landlord_wp, farmer_wp, landlord_adp, farmer_adp)``.

    Raises:
        ValueError: If ``eval_data`` contains no deals.
        RuntimeError: If a worker process exits with a non-zero exit code.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point eval_data at files from a trusted source.
    with open(eval_data, 'rb') as f:
        card_play_data_list = pickle.load(f)

    # Without any game the win totals stay at zero and the ratios below would
    # fail with an opaque ZeroDivisionError after all workers were spawned.
    if len(card_play_data_list) == 0:
        raise ValueError(
            'evaluation data {!r} contains no games'.format(eval_data))

    card_play_data_list_each_worker = data_allocation_per_worker(
        card_play_data_list, num_workers)
    del card_play_data_list

    card_play_model_path_dict = {
        'landlord': landlord,
        'landlord_up': landlord_up,
        'landlord_front': landlord_front,
        'landlord_down': landlord_down}

    num_landlord_wins = 0
    num_farmer_wins = 0
    num_landlord_scores = 0
    num_farmer_scores = 0

    ctx = mp.get_context('spawn')
    q = ctx.SimpleQueue()
    processes = []

    # The agents are built once here and handed to every worker as a process
    # argument.
    players = load_card_play_models(card_play_model_path_dict)

    for card_play_data in card_play_data_list_each_worker:
        p = ctx.Process(
            target=mp_simulate,
            args=(card_play_data, players, q, output, title))
        p.start()
        processes.append(p)

    # Lower the priority of this process and its children. psutil defines the
    # *_PRIORITY_CLASS constants on Windows only; on other platforms
    # Process.nice() expects a niceness value, so fall back to 10 there.
    low_priority = getattr(psutil, 'BELOW_NORMAL_PRIORITY_CLASS', 10)
    parent = psutil.Process()
    for proc in [parent] + parent.children():
        try:
            proc.nice(low_priority)
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # Best effort only: a worker may already have exited, or the OS
            # may refuse the change. Neither should abort the evaluation.
            pass

    for p in processes:
        p.join()

    # A worker that died never put its result on the queue; reading the queue
    # would then block forever, so fail loudly instead.
    failed_exit_codes = [p.exitcode for p in processes if p.exitcode != 0]
    if failed_exit_codes:
        raise RuntimeError(
            '{} of {} evaluation workers failed (exit codes: {})'.format(
                len(failed_exit_codes), len(processes), failed_exit_codes))

    # Exactly one result tuple per worker.
    for _ in processes:
        landlord_wins, farmer_wins, landlord_scores, farmer_scores = q.get()
        num_landlord_wins += landlord_wins
        num_farmer_wins += farmer_wins
        num_landlord_scores += landlord_scores
        num_farmer_scores += farmer_scores

    num_total_wins = num_landlord_wins + num_farmer_wins
    print('WP results:')
    landlord_wp = num_landlord_wins / num_total_wins
    farmer_wp = num_farmer_wins / num_total_wins
    print('landlord : Farmers - {} : {}'.format(landlord_wp, farmer_wp))
    print('ADP results:')
    landlord_adp = num_landlord_scores / num_total_wins
    # NOTE(review): the farmer score is scaled by 3, presumably because three
    # seats play on the farmer side -- confirm against GameEnv's scoring.
    farmer_adp = 3 * num_farmer_scores / num_total_wins
    print('landlord : Farmers - {} : {}'.format(landlord_adp, farmer_adp))
    return landlord_wp, farmer_wp, landlord_adp, farmer_adp
|