CPU Support
This commit is contained in:
parent
cb8a5921eb
commit
fa11fa7cfc
36
BidHelper.py
36
BidHelper.py
|
@ -1,36 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Created by: Vincentzyx
|
||||
from douzero.env.game import GameEnv
|
||||
from douzero.evaluation.deep_agent import DeepAgent
|
||||
|
||||
RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
|
||||
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
|
||||
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
|
||||
|
||||
card_play_model_path_dict = {
|
||||
'landlord': "baselines/douzero_WP/landlord.ckpt",
|
||||
'landlord_up': "baselines/douzero_WP/landlord_up.ckpt",
|
||||
'landlord_down': "baselines/douzero_WP/landlord_down.ckpt"
|
||||
}
|
||||
|
||||
user_position = "landlord" # 玩家角色代码:0-地主上家, 1-地主, 2-地主下家
|
||||
ai_players = [0, 0]
|
||||
ai_players[0] = user_position
|
||||
ai_players[1] = DeepAgent(user_position, card_play_model_path_dict[user_position])
|
||||
|
||||
env = GameEnv(ai_players)
|
||||
card_play_data_list = {}
|
||||
|
||||
def GetWinRate(cards):
|
||||
env.reset()
|
||||
card_play_data_list.update({
|
||||
'three_landlord_cards': [RealCard2EnvCard[i] for i in "333"],
|
||||
'landlord': [RealCard2EnvCard[i] for i in cards],
|
||||
'landlord_up': [RealCard2EnvCard[i] for i in "33333333333333333"],
|
||||
'landlord_down': [RealCard2EnvCard[i] for i in "33333333333333333"]
|
||||
})
|
||||
|
||||
env.card_play_init(card_play_data_list)
|
||||
action_message = env.step(user_position)
|
||||
win_rate = float(action_message["win_rate"].replace("%",""))
|
||||
return win_rate
|
|
@ -58,7 +58,10 @@ net.eval()
|
|||
if UseGPU:
|
||||
net = net.to(device)
|
||||
if os.path.exists("bid_weights.pkl"):
|
||||
net.load_state_dict(torch.load('bid_weights.pkl'))
|
||||
if torch.cuda.is_available():
|
||||
net.load_state_dict(torch.load('bid_weights.pkl'))
|
||||
else:
|
||||
net.load_state_dict(torch.load('bid_weights.pkl', map_location=torch.device("cpu")))
|
||||
|
||||
def predict(cards):
|
||||
input = RealToOnehot(cards)
|
||||
|
|
|
@ -54,10 +54,16 @@ class Net(nn.Module):
|
|||
|
||||
Nets = {"up": Net(), "down": Net()}
|
||||
if os.path.exists("landlord_up_weights.pkl"):
|
||||
Nets["up"].load_state_dict(torch.load("landlord_up_weights.pkl"))
|
||||
if torch.cuda.is_available():
|
||||
Nets["up"].load_state_dict(torch.load("landlord_up_weights.pkl"))
|
||||
else:
|
||||
Nets["up"].load_state_dict(torch.load("landlord_up_weights.pkl", map_location=torch.device("cpu")))
|
||||
Nets["up"].eval()
|
||||
if os.path.exists("landlord_down_weights.pkl"):
|
||||
Nets["down"].load_state_dict(torch.load("landlord_down_weights.pkl"))
|
||||
if torch.cuda.is_available():
|
||||
Nets["up"].load_state_dict(torch.load("landlord_down_weights.pkl"))
|
||||
else:
|
||||
Nets["up"].load_state_dict(torch.load("landlord_down_weights.pkl", map_location=torch.device("cpu")))
|
||||
Nets["down"].eval()
|
||||
|
||||
def predict(cards, llc, type="up"):
|
||||
|
|
|
@ -54,7 +54,10 @@ class Net(nn.Module):
|
|||
net = Net()
|
||||
net.eval()
|
||||
if os.path.exists("landlord_weights.pkl"):
|
||||
net.load_state_dict(torch.load('landlord_weights.pkl'))
|
||||
if torch.cuda.is_available():
|
||||
net.load_state_dict(torch.load('landlord_weights.pkl'))
|
||||
else:
|
||||
net.load_state_dict(torch.load('landlord_weights.pkl', map_location=torch.device("cpu")))
|
||||
else:
|
||||
print("landlord_weights.pkl not found")
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
torch>=1.6.0
|
||||
GitPython>=3.0.5
|
||||
gitdb2>=2.0.6
|
||||
PyAutoGUI>=0.9.50
|
||||
PyQt5>=5.13.0
|
||||
PyQt5-sip>=12.8.1
|
||||
torch==1.6.0
|
||||
GitPython==3.0.5
|
||||
gitdb2==2.0.6
|
||||
PyAutoGUI==0.9.50
|
||||
PyQt5==5.13.0
|
||||
PyQt5-sip==12.8.1
|
||||
rlcard
|
||||
pywin32
|
||||
matplotlib
|
||||
opencv-python
|
Loading…
Reference in New Issue