CPU Support

This commit is contained in:
Vincentzyx 2021-07-28 21:05:45 +08:00
parent cb8a5921eb
commit fa11fa7cfc
5 changed files with 26 additions and 47 deletions

View File

@ -1,36 +0,0 @@
# -*- coding: utf-8 -*-
# Created by: Vincentzyx
from douzero.env.game import GameEnv
from douzero.evaluation.deep_agent import DeepAgent
RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
card_play_model_path_dict = {
'landlord': "baselines/douzero_WP/landlord.ckpt",
'landlord_up': "baselines/douzero_WP/landlord_up.ckpt",
'landlord_down': "baselines/douzero_WP/landlord_down.ckpt"
}
user_position = "landlord" # 玩家角色代码0-地主上家, 1-地主, 2-地主下家
ai_players = [0, 0]
ai_players[0] = user_position
ai_players[1] = DeepAgent(user_position, card_play_model_path_dict[user_position])
env = GameEnv(ai_players)
card_play_data_list = {}
def GetWinRate(cards):
env.reset()
card_play_data_list.update({
'three_landlord_cards': [RealCard2EnvCard[i] for i in "333"],
'landlord': [RealCard2EnvCard[i] for i in cards],
'landlord_up': [RealCard2EnvCard[i] for i in "33333333333333333"],
'landlord_down': [RealCard2EnvCard[i] for i in "33333333333333333"]
})
env.card_play_init(card_play_data_list)
action_message = env.step(user_position)
win_rate = float(action_message["win_rate"].replace("%",""))
return win_rate

View File

@ -58,7 +58,10 @@ net.eval()
if UseGPU:
net = net.to(device)
if os.path.exists("bid_weights.pkl"):
net.load_state_dict(torch.load('bid_weights.pkl'))
if torch.cuda.is_available():
net.load_state_dict(torch.load('bid_weights.pkl'))
else:
net.load_state_dict(torch.load('bid_weights.pkl', map_location=torch.device("cpu")))
def predict(cards):
input = RealToOnehot(cards)

View File

@ -54,10 +54,16 @@ class Net(nn.Module):
Nets = {"up": Net(), "down": Net()}
if os.path.exists("landlord_up_weights.pkl"):
Nets["up"].load_state_dict(torch.load("landlord_up_weights.pkl"))
if torch.cuda.is_available():
Nets["up"].load_state_dict(torch.load("landlord_up_weights.pkl"))
else:
Nets["up"].load_state_dict(torch.load("landlord_up_weights.pkl", map_location=torch.device("cpu")))
Nets["up"].eval()
if os.path.exists("landlord_down_weights.pkl"):
Nets["down"].load_state_dict(torch.load("landlord_down_weights.pkl"))
if torch.cuda.is_available():
Nets["up"].load_state_dict(torch.load("landlord_down_weights.pkl"))
else:
Nets["up"].load_state_dict(torch.load("landlord_down_weights.pkl", map_location=torch.device("cpu")))
Nets["down"].eval()
def predict(cards, llc, type="up"):

View File

@ -54,7 +54,10 @@ class Net(nn.Module):
net = Net()
net.eval()
if os.path.exists("landlord_weights.pkl"):
net.load_state_dict(torch.load('landlord_weights.pkl'))
if torch.cuda.is_available():
net.load_state_dict(torch.load('landlord_weights.pkl'))
else:
net.load_state_dict(torch.load('landlord_weights.pkl', map_location=torch.device("cpu")))
else:
print("landlord_weights.pkl not found")

View File

@ -1,7 +1,10 @@
torch>=1.6.0
GitPython>=3.0.5
gitdb2>=2.0.6
PyAutoGUI>=0.9.50
PyQt5>=5.13.0
PyQt5-sip>=12.8.1
rlcard
torch==1.6.0
GitPython==3.0.5
gitdb2==2.0.6
PyAutoGUI==0.9.50
PyQt5==5.13.0
PyQt5-sip==12.8.1
rlcard
pywin32
matplotlib
opencv-python