From fa11fa7cfc615990d4e34e202578e00e6b42dd88 Mon Sep 17 00:00:00 2001 From: Vincentzyx <929403983@qq.com> Date: Wed, 28 Jul 2021 21:05:45 +0800 Subject: [PATCH] CPU Support --- BidHelper.py | 36 ------------------------------------ BidModel.py | 5 ++++- FarmerModel.py | 10 ++++++++-- LandlordModel.py | 5 ++++- requirements.txt | 17 ++++++++++------- 5 files changed, 26 insertions(+), 47 deletions(-) delete mode 100644 BidHelper.py diff --git a/BidHelper.py b/BidHelper.py deleted file mode 100644 index 27a46d0..0000000 --- a/BidHelper.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -# Created by: Vincentzyx -from douzero.env.game import GameEnv -from douzero.evaluation.deep_agent import DeepAgent - -RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7, - '8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12, - 'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30} - -card_play_model_path_dict = { - 'landlord': "baselines/douzero_WP/landlord.ckpt", - 'landlord_up': "baselines/douzero_WP/landlord_up.ckpt", - 'landlord_down': "baselines/douzero_WP/landlord_down.ckpt" -} - -user_position = "landlord" # 玩家角色代码:0-地主上家, 1-地主, 2-地主下家 -ai_players = [0, 0] -ai_players[0] = user_position -ai_players[1] = DeepAgent(user_position, card_play_model_path_dict[user_position]) - -env = GameEnv(ai_players) -card_play_data_list = {} - -def GetWinRate(cards): - env.reset() - card_play_data_list.update({ - 'three_landlord_cards': [RealCard2EnvCard[i] for i in "333"], - 'landlord': [RealCard2EnvCard[i] for i in cards], - 'landlord_up': [RealCard2EnvCard[i] for i in "33333333333333333"], - 'landlord_down': [RealCard2EnvCard[i] for i in "33333333333333333"] - }) - - env.card_play_init(card_play_data_list) - action_message = env.step(user_position) - win_rate = float(action_message["win_rate"].replace("%","")) - return win_rate \ No newline at end of file diff --git a/BidModel.py b/BidModel.py index 19f2e2b..0bef77b 100644 --- a/BidModel.py +++ b/BidModel.py @@ -58,7 +58,10 @@ net.eval() if UseGPU: net = net.to(device) if os.path.exists("bid_weights.pkl"): - net.load_state_dict(torch.load('bid_weights.pkl')) + if torch.cuda.is_available(): + net.load_state_dict(torch.load('bid_weights.pkl')) + else: + net.load_state_dict(torch.load('bid_weights.pkl', map_location=torch.device("cpu"))) def predict(cards): input = RealToOnehot(cards) diff --git a/FarmerModel.py b/FarmerModel.py index 480271a..368334a 100644 --- a/FarmerModel.py +++ b/FarmerModel.py @@ -54,10 +54,16 @@ class Net(nn.Module): Nets = {"up": Net(), "down": Net()} if os.path.exists("landlord_up_weights.pkl"): - Nets["up"].load_state_dict(torch.load("landlord_up_weights.pkl")) + if torch.cuda.is_available(): + Nets["up"].load_state_dict(torch.load("landlord_up_weights.pkl")) + else: + Nets["up"].load_state_dict(torch.load("landlord_up_weights.pkl", map_location=torch.device("cpu"))) Nets["up"].eval() if os.path.exists("landlord_down_weights.pkl"): - Nets["down"].load_state_dict(torch.load("landlord_down_weights.pkl")) + if torch.cuda.is_available(): + Nets["up"].load_state_dict(torch.load("landlord_down_weights.pkl")) + else: + Nets["up"].load_state_dict(torch.load("landlord_down_weights.pkl", map_location=torch.device("cpu"))) Nets["down"].eval() def predict(cards, llc, type="up"): diff --git a/LandlordModel.py b/LandlordModel.py index e9c7bda..48c4532 100644 --- a/LandlordModel.py +++ b/LandlordModel.py @@ -54,7 +54,10 @@ class Net(nn.Module): net = Net() net.eval() if os.path.exists("landlord_weights.pkl"): - net.load_state_dict(torch.load('landlord_weights.pkl')) + if torch.cuda.is_available(): + net.load_state_dict(torch.load('landlord_weights.pkl')) + else: + net.load_state_dict(torch.load('landlord_weights.pkl', map_location=torch.device("cpu"))) else: print("landlord_weights.pkl not found") diff --git a/requirements.txt b/requirements.txt index d9ac571..b38ff11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,10 @@ -torch>=1.6.0 -GitPython>=3.0.5 -gitdb2>=2.0.6 -PyAutoGUI>=0.9.50 -PyQt5>=5.13.0 -PyQt5-sip>=12.8.1 -rlcard \ No newline at end of file +torch==1.6.0 +GitPython==3.0.5 +gitdb2==2.0.6 +PyAutoGUI==0.9.50 +PyQt5==5.13.0 +PyQt5-sip==12.8.1 +rlcard +pywin32 +matplotlib +opencv-python \ No newline at end of file