diff --git a/.gitignore b/.gitignore index a046d76..76ce144 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ yarn-error.log* db.sqlite3 __pycache__ +*.swp diff --git a/README.md b/README.md index 863fddc..434136d 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,15 @@ The definitions of the fields are as follows: | http://127.0.0.1:8000/tournament/query_payoff | Get all the payoffs | | http://127.0.0.1:8000/tournament/query_payoff?agent0=leduc-holdem-cfr&agent1=leduc-holdem-rule-v1 | Get all the payoffs between rule and CFR models | - +## Registered Models +Some models have been pre-registered as baselines: +| Model | Game | Description | +|----------------------|--------------|---------------------------------------| +| leduc-holdem-random | leduc-holdem | A random model | +| leduc-holdem-cfr | leduc-holdem | Pre-trained CFR model on Leduc Holdem | +| leduc-holdem-rule-v1 | leduc-holdem | A rule model that plays greedily | +| doudizhu-random | doudizhu | A random model | +| doudizhu-rule-v1 | doudizhu | Dou Dizhu rule model | # Others diff --git a/server/tournament/rlcard_wrap/__init__.py b/server/tournament/rlcard_wrap/__init__.py index 721f55e..f7080a4 100644 --- a/server/tournament/rlcard_wrap/__init__.py +++ b/server/tournament/rlcard_wrap/__init__.py @@ -19,6 +19,6 @@ MODEL_IDS['leduc-holdem'] = [ MODEL_IDS['doudizhu'] = [ 'doudizhu-random', - 'doudizhu-random', + 'doudizhu-rule-v1', ] diff --git a/server/tournament/tournament.py b/server/tournament/tournament.py index fb99668..c6820da 100644 --- a/server/tournament/tournament.py +++ b/server/tournament/tournament.py @@ -5,15 +5,6 @@ import numpy as np from .rlcard_wrap import rlcard -def cards2str(cards): - response = '' - for card in cards: - if card.rank == '': - response += card.suit[0] - else: - response += card.rank - return response - class Tournament(object): def __init__(self, game, model_ids, evaluate_num=100): @@ -69,11 +60,7 @@ class Tournament(object): return games_data, payoffs_data 
def doudizhu_tournament(game, agents, names, num): - import rlcard env = rlcard.make(game, config={'allow_raw_data': True}) - print(env.reset()) - print(env.step(87, False)) - exit() env.set_agents(agents) payoffs = [] json_data = [] @@ -83,9 +70,8 @@ def doudizhu_tournament(game, agents, names, num): roles = ['landlord', 'peasant', 'peasant'] data['playerInfo'] = [{'id': i, 'index': i, 'role': roles[i], 'agentInfo': {'name': names[i]}} for i in range(env.player_num)] state, player_id = env.reset() - #perfect = env.get_perfect_information() - #data['initHands'] = perfect['hand_cards'] - data['initHands'] =[cards2str(env.game.players[i].current_hand) for i in range(env.player_num)] + perfect = env.get_perfect_information() + data['initHands'] = perfect['hand_cards_with_suit'] data['moveHistory'] = [] while not env.is_over(): action, probs = env.agents[player_id].eval_step(state) @@ -97,12 +83,9 @@ def doudizhu_tournament(game, agents, names, num): history['move'] = env._decode_action(action) data['moveHistory'].append(history) - print(action, player_id, env.agents[player_id].use_raw) state, player_id = env.step(action, env.agents[player_id].use_raw) data = json.dumps(data) #data = json.dumps(data, indent=2, sort_keys=True) - print(data) - exit() json_data.append(data) if env.get_payoffs()[0] > 0: wins.append(True)