Initial commit
|
@ -0,0 +1,2 @@
|
|||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
|
@ -0,0 +1,114 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
|
@ -0,0 +1,3 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
|
@ -0,0 +1,12 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.8" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
|
@ -0,0 +1,35 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<Languages>
|
||||
<language minSize="78" name="Python" />
|
||||
</Languages>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<option name="ignoredErrors">
|
||||
<list>
|
||||
<option value="E501" />
|
||||
<option value="E302" />
|
||||
<option value="E303" />
|
||||
</list>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<option name="ignoredErrors">
|
||||
<list>
|
||||
<option value="N803" />
|
||||
<option value="N806" />
|
||||
<option value="N802" />
|
||||
<option value="N801" />
|
||||
</list>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="SpellCheckingInspection" enabled="false" level="TYPO" enabled_by_default="false">
|
||||
<option name="processCode" value="true" />
|
||||
<option name="processLiterals" value="true" />
|
||||
<option name="processComments" value="true" />
|
||||
</inspection_tool>
|
||||
<inspection_tool class="SqlNoDataSourceInspection" enabled="false" level="WARNING" enabled_by_default="false" />
|
||||
</profile>
|
||||
</component>
|
|
@ -0,0 +1,6 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
|
@ -0,0 +1,4 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
|
||||
</project>
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/DouZero_For_HappyDouDiZhu-master.iml" filepath="$PROJECT_DIR$/.idea/DouZero_For_HappyDouDiZhu-master.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,36 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Created by: Vincentzyx
|
||||
from douzero.env.game import GameEnv
|
||||
from douzero.evaluation.deep_agent import DeepAgent
|
||||
|
||||
RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
|
||||
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
|
||||
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
|
||||
|
||||
card_play_model_path_dict = {
|
||||
'landlord': "baselines/douzero_WP/landlord.ckpt",
|
||||
'landlord_up': "baselines/douzero_WP/landlord_up.ckpt",
|
||||
'landlord_down': "baselines/douzero_WP/landlord_down.ckpt"
|
||||
}
|
||||
|
||||
user_position = "landlord" # 玩家角色代码:0-地主上家, 1-地主, 2-地主下家
|
||||
ai_players = [0, 0]
|
||||
ai_players[0] = user_position
|
||||
ai_players[1] = DeepAgent(user_position, card_play_model_path_dict[user_position])
|
||||
|
||||
env = GameEnv(ai_players)
|
||||
card_play_data_list = {}
|
||||
|
||||
def GetWinRate(cards):
|
||||
env.reset()
|
||||
card_play_data_list.update({
|
||||
'three_landlord_cards': [RealCard2EnvCard[i] for i in "333"],
|
||||
'landlord': [RealCard2EnvCard[i] for i in cards],
|
||||
'landlord_up': [RealCard2EnvCard[i] for i in "33333333333333333"],
|
||||
'landlord_down': [RealCard2EnvCard[i] for i in "33333333333333333"]
|
||||
})
|
||||
|
||||
env.card_play_init(card_play_data_list)
|
||||
action_message = env.step(user_position)
|
||||
win_rate = float(action_message["win_rate"].replace("%",""))
|
||||
return win_rate
|
|
@ -0,0 +1,69 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Created by: Vincentzyx
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
import time
|
||||
|
||||
|
||||
def EnvToOnehot(cards):
|
||||
Env2IdxMap = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8,12:9,13:10,14:11,17:12,20:13,30:14}
|
||||
cards = [Env2IdxMap[i] for i in cards]
|
||||
Onehot = torch.zeros((4,15))
|
||||
for i in range(0, 15):
|
||||
Onehot[:cards.count(i),i] = 1
|
||||
return Onehot
|
||||
|
||||
def RealToOnehot(cards):
|
||||
RealCard2EnvCard = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4,
|
||||
'8': 5, '9': 6, 'T': 7, 'J': 8, 'Q': 9,
|
||||
'K': 10, 'A': 11, '2': 12, 'X': 13, 'D': 14}
|
||||
cards = [RealCard2EnvCard[c] for c in cards]
|
||||
Onehot = torch.zeros((4,15))
|
||||
for i in range(0, 15):
|
||||
Onehot[:cards.count(i),i] = 1
|
||||
return Onehot
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.fc1 = nn.Linear(60, 512)
|
||||
self.fc2 = nn.Linear(512, 512)
|
||||
self.fc3 = nn.Linear(512, 512)
|
||||
self.fc4 = nn.Linear(512, 512)
|
||||
self.fc5 = nn.Linear(512, 512)
|
||||
self.fc6 = nn.Linear(512, 1)
|
||||
self.dropout5 = nn.Dropout(0.5)
|
||||
self.dropout3 = nn.Dropout(0.3)
|
||||
self.dropout1 = nn.Dropout(0.1)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.fc1(input)
|
||||
x = torch.relu(self.dropout1(self.fc2(x)))
|
||||
x = torch.relu(self.dropout3(self.fc3(x)))
|
||||
x = torch.relu(self.dropout5(self.fc4(x)))
|
||||
x = torch.relu(self.dropout5(self.fc5(x)))
|
||||
x = self.fc6(x)
|
||||
return x
|
||||
|
||||
|
||||
UseGPU = False
|
||||
device = torch.device('cuda:0')
|
||||
net = Net()
|
||||
net.eval()
|
||||
if UseGPU:
|
||||
net = net.to(device)
|
||||
if os.path.exists("bid_weights.pkl"):
|
||||
net.load_state_dict(torch.load('bid_weights.pkl'))
|
||||
|
||||
def predict(cards):
|
||||
input = RealToOnehot(cards)
|
||||
if UseGPU:
|
||||
input = input.to(device)
|
||||
input = torch.flatten(input)
|
||||
win_rate = net(input)
|
||||
return win_rate[0].item() * 100
|
|
@ -0,0 +1,67 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Created by: Vincentzyx
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
import time
|
||||
|
||||
|
||||
def EnvToOnehot(cards):
|
||||
Env2IdxMap = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8,12:9,13:10,14:11,17:12,20:13,30:14}
|
||||
cards = [Env2IdxMap[i] for i in cards]
|
||||
Onehot = torch.zeros((4,15))
|
||||
for i in range(0, 15):
|
||||
Onehot[:cards.count(i),i] = 1
|
||||
return Onehot
|
||||
|
||||
def RealToOnehot(cards, llc):
|
||||
RealCard2EnvCard = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4,
|
||||
'8': 5, '9': 6, 'T': 7, 'J': 8, 'Q': 9,
|
||||
'K': 10, 'A': 11, '2': 12, 'X': 13, 'D': 14}
|
||||
cards = [RealCard2EnvCard[c] for c in cards]
|
||||
llcs = [RealCard2EnvCard[c] for c in llc]
|
||||
Onehot = torch.zeros((7,15))
|
||||
for i in range(0, 15):
|
||||
Onehot[:cards.count(i),i] = 1
|
||||
Onehot[4:llcs.count(i)+4,i] = 1
|
||||
return Onehot
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(105, 512)
|
||||
self.fc2 = nn.Linear(512, 512)
|
||||
self.fc3 = nn.Linear(512, 512)
|
||||
self.fc4 = nn.Linear(512, 512)
|
||||
self.fc5 = nn.Linear(512, 512)
|
||||
self.fc6 = nn.Linear(512, 1)
|
||||
self.dropout5 = nn.Dropout(0.5)
|
||||
self.dropout3 = nn.Dropout(0.3)
|
||||
self.dropout1 = nn.Dropout(0.1)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.fc1(input)
|
||||
x = torch.relu(self.dropout3(self.fc2(x)))
|
||||
x = torch.relu(self.dropout5(self.fc3(x)))
|
||||
x = torch.relu(self.dropout5(self.fc4(x)))
|
||||
x = torch.relu(self.dropout5(self.fc5(x)))
|
||||
x = self.fc6(x)
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
Nets = {"up": Net(), "down": Net()}
|
||||
if os.path.exists("landlord_up_weights.pkl"):
|
||||
Nets["up"].load_state_dict(torch.load("landlord_up_weights.pkl"))
|
||||
Nets["up"].eval()
|
||||
if os.path.exists("landlord_down_weights.pkl"):
|
||||
Nets["down"].load_state_dict(torch.load("landlord_down_weights.pkl"))
|
||||
Nets["down"].eval()
|
||||
|
||||
def predict(cards, llc, type="up"):
|
||||
net = Nets[type]
|
||||
x = torch.flatten(RealToOnehot(cards, llc))
|
||||
y = net(x)[0].item()
|
||||
return y * 100
|
|
@ -0,0 +1,452 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Created by: Vincentzyx
|
||||
import win32gui
|
||||
import win32ui
|
||||
import win32api
|
||||
from ctypes import windll
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import pyautogui
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
from win32con import WM_LBUTTONDOWN, MK_LBUTTON, WM_LBUTTONUP, WM_MOUSEMOVE
|
||||
import multiprocessing as mp
|
||||
|
||||
from PyQt5 import QtGui, QtWidgets, QtCore
|
||||
from PyQt5.QtCore import QTime, QEventLoop
|
||||
|
||||
Pics = {}
|
||||
ReqQueue = mp.Queue()
|
||||
ResultQueue = mp.Queue()
|
||||
Processes = []
|
||||
|
||||
def GetSingleCardQueue(reqQ, resQ, Pics):
|
||||
while True:
|
||||
while not reqQ.empty():
|
||||
image, i, sx, sy, sw, sh, checkSelect = reqQ.get()
|
||||
result = GetSingleCard(image, i, sx, sy, sw, sh, checkSelect, Pics)
|
||||
del image
|
||||
if result is not None:
|
||||
resQ.put(result)
|
||||
time.sleep(0.01)
|
||||
|
||||
def ShowImg(image):
|
||||
plt.imshow(image)
|
||||
plt.show()
|
||||
|
||||
def DrawRectWithText(image, rect, text):
|
||||
img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
||||
x, y, w, h = rect
|
||||
img2 = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)
|
||||
img2 = cv2.putText(img2, text, (x, y + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
|
||||
return Image.fromarray(cv2.cvtColor(img2, cv2.COLOR_BGR2RGB))
|
||||
|
||||
def CompareCard(card):
|
||||
order = {"3": 0, "4": 1, "5": 2, "6": 3, "7": 4, "8": 5, "9": 6, "T": 7, "J": 8, "Q": 9, "K": 10, "A": 11, "2": 12,
|
||||
"X": 13, "D": 14}
|
||||
return order[card]
|
||||
|
||||
def CompareCardInfo(card):
|
||||
order = {"3": 0, "4": 1, "5": 2, "6": 3, "7": 4, "8": 5, "9": 6, "T": 7, "J": 8, "Q": 9, "K": 10, "A": 11, "2": 12,
|
||||
"X": 13, "D": 14}
|
||||
return order[card[0]]
|
||||
|
||||
def CompareCards(cards1, cards2):
|
||||
if len(cards1) != len(cards2):
|
||||
return False
|
||||
cards1.sort(key=CompareCard)
|
||||
cards2.sort(key=CompareCard)
|
||||
for i in range(0, len(cards1)):
|
||||
if cards1[i] != cards2[i]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def GetListDifference(l1, l2):
|
||||
temp1 = []
|
||||
temp1.extend(l1)
|
||||
temp2 = []
|
||||
temp2.extend(l2)
|
||||
for i in l2:
|
||||
if i in temp1:
|
||||
temp1.remove(i)
|
||||
for i in l1:
|
||||
if i in temp2:
|
||||
temp2.remove(i)
|
||||
return temp1, temp2
|
||||
|
||||
def FindImage(fromImage, template, threshold=0.9):
|
||||
w, h, _ = template.shape
|
||||
fromImage = cv2.cvtColor(np.asarray(fromImage), cv2.COLOR_RGB2BGR)
|
||||
res = cv2.matchTemplate(fromImage, template, cv2.TM_CCOEFF_NORMED)
|
||||
loc = np.where(res >= threshold)
|
||||
points = []
|
||||
for pt in zip(*loc[::-1]):
|
||||
points.append(pt)
|
||||
return points
|
||||
|
||||
def GetSingleCard(image, i, sx, sy, sw, sh, checkSelect, Pics):
|
||||
cardSearchFrom = 0
|
||||
AllCardsNC = ['rD', 'bX', '2', 'A', 'K', 'Q', 'J', 'T', '9', '8', '7', '6', '5', '4', '3']
|
||||
currCard = ""
|
||||
ci = cardSearchFrom
|
||||
while ci < len(AllCardsNC):
|
||||
if "r" in AllCardsNC[ci] or "b" in AllCardsNC[ci]:
|
||||
result = pyautogui.locate(needleImage=Pics["m" + AllCardsNC[ci]], haystackImage=image,
|
||||
region=(sx + 50 * i, sy - checkSelect * 25, sw, sh), confidence=0.9)
|
||||
if result is not None:
|
||||
cardPos = (sx + 50 * i + sw // 2, sy - checkSelect * 25 + sh // 2)
|
||||
cardSearchFrom = ci
|
||||
currCard = AllCardsNC[ci][1]
|
||||
cardInfo = (currCard, cardPos)
|
||||
return cardInfo
|
||||
break
|
||||
else:
|
||||
outerBreak = False
|
||||
for card_type in ["r", "b"]:
|
||||
result = pyautogui.locate(needleImage=Pics["m" + card_type + AllCardsNC[ci]],
|
||||
haystackImage=image,
|
||||
region=(sx + 50 * i, sy - checkSelect * 25, sw, sh), confidence=0.9)
|
||||
if result is not None:
|
||||
cardPos = (sx + 50 * i + sw // 2, sy - checkSelect * 25 + sh // 2)
|
||||
cardSearchFrom = ci
|
||||
currCard = AllCardsNC[ci]
|
||||
cardInfo = (currCard, cardPos)
|
||||
outerBreak = True
|
||||
return cardInfo
|
||||
break
|
||||
if outerBreak:
|
||||
break
|
||||
if ci == len(AllCardsNC) - 1 and checkSelect == 0:
|
||||
checkSelect = 1
|
||||
ci = cardSearchFrom - 1
|
||||
ci += 1
|
||||
return None
|
||||
|
||||
def RunThreads():
|
||||
for file in os.listdir("pics"):
|
||||
info = file.split(".")
|
||||
if info[1] == "png":
|
||||
tmpImage = Image.open("pics/" + file)
|
||||
Pics.update({info[0]: tmpImage})
|
||||
for ti in range(20):
|
||||
p = mp.Process(target=GetSingleCardQueue, args=(ReqQueue, ResultQueue, Pics))
|
||||
p.start()
|
||||
|
||||
|
||||
def LocateOnImage(image, template, region=None, confidence=0.9):
|
||||
if region is not None:
|
||||
x, y, w, h = region
|
||||
imgShape = image.shape
|
||||
image = image[y:y+h, x:x+w,:]
|
||||
res = cv2.matchTemplate(image, template, cv2.TM_CCOEFF_NORMED)
|
||||
if (res >= confidence).any():
|
||||
return True
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class GameHelper:
|
||||
def __init__(self):
|
||||
self.ScreenZoomRate = 1.25
|
||||
self.Pics = {}
|
||||
self.PicsCV = {}
|
||||
self.Handle = win32gui.FindWindow("Hlddz", None)
|
||||
self.Interrupt = False
|
||||
for file in os.listdir("pics"):
|
||||
info = file.split(".")
|
||||
if info[1] == "png":
|
||||
tmpImage = Image.open("pics/" + file)
|
||||
imgCv = cv2.imread("pics/" + file)
|
||||
self.Pics.update({info[0]: tmpImage})
|
||||
self.PicsCV.update({info[0]: imgCv})
|
||||
|
||||
def Screenshot(self, region=None): # -> (im, (left, top))
|
||||
hwnd = self.Handle
|
||||
# im = Image.open(r"C:\Users\q9294\Desktop\llc.png")
|
||||
# im = im.resize((1796, 1047))
|
||||
# return im, (0,0)
|
||||
left, top, right, bot = win32gui.GetWindowRect(hwnd)
|
||||
width = right - left
|
||||
height = bot - top
|
||||
width = int(width / self.ScreenZoomRate)
|
||||
height = int(height / self.ScreenZoomRate)
|
||||
hwndDC = win32gui.GetWindowDC(hwnd)
|
||||
mfcDC = win32ui.CreateDCFromHandle(hwndDC)
|
||||
saveDC = mfcDC.CreateCompatibleDC()
|
||||
saveBitMap = win32ui.CreateBitmap()
|
||||
saveBitMap.CreateCompatibleBitmap(mfcDC, width, height)
|
||||
saveDC.SelectObject(saveBitMap)
|
||||
result = windll.user32.PrintWindow(hwnd, saveDC.GetSafeHdc(), 0)
|
||||
bmpinfo = saveBitMap.GetInfo()
|
||||
bmpstr = saveBitMap.GetBitmapBits(True)
|
||||
im = Image.frombuffer(
|
||||
"RGB",
|
||||
(bmpinfo['bmWidth'], bmpinfo['bmHeight']),
|
||||
bmpstr, 'raw', 'BGRX', 0, 1)
|
||||
win32gui.DeleteObject(saveBitMap.GetHandle())
|
||||
saveDC.DeleteDC()
|
||||
mfcDC.DeleteDC()
|
||||
win32gui.ReleaseDC(hwnd, hwndDC)
|
||||
im = im.resize((1796, 1047))
|
||||
if region is not None:
|
||||
im = im.crop((region[0], region[1], region[0] + region[2], region[1] + region[3]))
|
||||
if result:
|
||||
return im, (left, top)
|
||||
else:
|
||||
return None, (0, 0)
|
||||
|
||||
def LocateOnScreen(self, templateName, region, confidence=0.9):
|
||||
image, _ = self.Screenshot()
|
||||
return pyautogui.locate(needleImage=self.Pics[templateName],
|
||||
haystackImage=image, region=region, confidence=confidence)
|
||||
|
||||
def ClickOnImage(self, templateName, region=None, confidence=0.9):
|
||||
image, _ = self.Screenshot()
|
||||
result = pyautogui.locate(needleImage=self.Pics[templateName], haystackImage=image, confidence=confidence, region=region)
|
||||
if result is not None:
|
||||
self.LeftClick((result[0],result[1]))
|
||||
|
||||
def GetCardsState(self, image):
|
||||
st = time.time()
|
||||
states = []
|
||||
cardStartPos = pyautogui.locate(needleImage=self.Pics["card_edge"], haystackImage=image,
|
||||
region=(313, 747, 1144, 200), confidence=0.85)
|
||||
if cardStartPos is None:
|
||||
return []
|
||||
sx = cardStartPos[0] + 10
|
||||
cardSearchFrom = 0
|
||||
sy, sw, sh = 770, 50, 55
|
||||
for i in range(0, 20):
|
||||
haveWhite = pyautogui.locate(needleImage=self.Pics["card_white"], haystackImage=image,
|
||||
region=(sx + 50 * i, sy, 50, 50), confidence=0.8)
|
||||
if haveWhite is not None:
|
||||
break
|
||||
result = pyautogui.locate(needleImage=self.Pics["card_upper_edge"], haystackImage=image,
|
||||
region=(sx + 50 * i, 720, sw, 38), confidence=0.9)
|
||||
checkSelect = 0
|
||||
if result is not None:
|
||||
result = pyautogui.locate(needleImage=self.Pics['card_overlap'], haystackImage=image,
|
||||
region=(sx + 50 * i, 750, sw, 38), confidence=0.85)
|
||||
if result is None:
|
||||
checkSelect = 1
|
||||
states.append(checkSelect)
|
||||
print("GetStates Costs ", time.time()-st)
|
||||
return states
|
||||
|
||||
def GetCardsMulti(self, image):
|
||||
st = time.time()
|
||||
cardStartPos = pyautogui.locate(needleImage=self.Pics["card_edge"], haystackImage=image,
|
||||
region=(313, 747, 1144, 200), confidence=0.85)
|
||||
if cardStartPos is None:
|
||||
return [],[]
|
||||
sx = cardStartPos[0] + 10
|
||||
AllCardsNC = ['rD', 'bX', '2', 'A', 'K', 'Q', 'J', 'T', '9', '8', '7', '6', '5', '4', '3']
|
||||
hand_cards = []
|
||||
select_map = []
|
||||
cardSearchFrom = 0
|
||||
sy, sw, sh = 770, 50, 55
|
||||
for i in range(0, 20):
|
||||
haveWhite = pyautogui.locate(needleImage=self.Pics["card_white"], haystackImage=image,
|
||||
region=(sx + 50 * i, sy, 60, 60), confidence=0.8)
|
||||
if haveWhite is not None:
|
||||
break
|
||||
result = pyautogui.locate(needleImage=self.Pics["card_upper_edge"], haystackImage=image,
|
||||
region=(sx + 50 * i, 720, sw, 50), confidence=0.9)
|
||||
checkSelect = 0
|
||||
if result is not None:
|
||||
result = pyautogui.locate(needleImage=self.Pics['card_overlap'], haystackImage=image,
|
||||
region=(sx + 50 * i, 750, sw, 50), confidence=0.85)
|
||||
if result is None:
|
||||
checkSelect = 1
|
||||
select_map.append(checkSelect)
|
||||
ReqQueue.put((image, i, sx, sy, sw, sh, checkSelect))
|
||||
QtWidgets.QApplication.processEvents(QEventLoop.AllEvents, 10)
|
||||
st = time.time()
|
||||
while len(hand_cards) != len(select_map):
|
||||
while not ResultQueue.empty():
|
||||
hand_cards.append(ResultQueue.get())
|
||||
time.sleep(0.01)
|
||||
QtWidgets.QApplication.processEvents(QEventLoop.AllEvents, 10)
|
||||
hand_cards.sort(key=CompareCardInfo, reverse=True)
|
||||
print("GetCardsMP Costs ", time.time()-st)
|
||||
return hand_cards, select_map
|
||||
|
||||
def GetCards(self, image):
|
||||
st = time.time()
|
||||
imgCv = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
||||
cardStartPos = pyautogui.locate(needleImage=self.Pics["card_edge"], haystackImage=image,
|
||||
region=(313, 747, 1144, 200), confidence=0.85)
|
||||
if cardStartPos is None:
|
||||
return [],[]
|
||||
sx = cardStartPos[0] + 10
|
||||
AllCardsNC = ['rD', 'bX', '2', 'A', 'K', 'Q', 'J', 'T', '9', '8', '7', '6', '5', '4', '3']
|
||||
hand_cards = []
|
||||
select_map = []
|
||||
cardSearchFrom = 0
|
||||
sy, sw, sh = 770, 50, 55
|
||||
for i in range(0, 20):
|
||||
haveWhite = pyautogui.locate(needleImage=self.Pics["card_white"], haystackImage=image,
|
||||
region=(sx + 50 * i, sy, 60, 60), confidence=0.8)
|
||||
if haveWhite is not None:
|
||||
break
|
||||
result = pyautogui.locate(needleImage=self.Pics["card_upper_edge"], haystackImage=image,
|
||||
region=(sx + 50 * i, 720, sw, 50), confidence=0.9)
|
||||
checkSelect = 0
|
||||
if result is not None:
|
||||
result = pyautogui.locate(needleImage=self.Pics['card_overlap'], haystackImage=image,
|
||||
region=(sx + 50 * i, 750, sw, 50), confidence=0.85)
|
||||
if result is None:
|
||||
checkSelect = 1
|
||||
select_map.append(checkSelect)
|
||||
currCard = ""
|
||||
ci = cardSearchFrom
|
||||
while ci < len(AllCardsNC):
|
||||
if "r" in AllCardsNC[ci] or "b" in AllCardsNC[ci]:
|
||||
result = LocateOnImage(imgCv, self.PicsCV["m" + AllCardsNC[ci]], region=(sx + 50 * i, sy - checkSelect * 25, sw, sh), confidence=0.91)
|
||||
# result = pyautogui.locate(needleImage=self.Pics["m" + AllCardsNC[ci]], haystackImage=image,
|
||||
# region=(sx + 50 * i, sy - checkSelect * 25, sw, sh), confidence=0.9)
|
||||
if result is not None:
|
||||
cardPos = (sx + 50 * i + sw // 2, sy - checkSelect * 25 + sh // 2)
|
||||
cardSearchFrom = ci
|
||||
currCard = AllCardsNC[ci][1]
|
||||
cardInfo = (currCard, cardPos)
|
||||
hand_cards.append(cardInfo)
|
||||
else:
|
||||
outerBreak = False
|
||||
for card_type in ["r", "b"]:
|
||||
result = LocateOnImage(imgCv, self.PicsCV["m" + card_type + AllCardsNC[ci]], region=(sx + 50 * i, sy - checkSelect * 25, sw, sh), confidence=0.91)
|
||||
# result = pyautogui.locate(needleImage=self.Pics["m" + card_type + AllCardsNC[ci]],
|
||||
# haystackImage=image,
|
||||
# region=(sx + 50 * i, sy - checkSelect * 25, sw, sh), confidence=0.9)
|
||||
if result is not None:
|
||||
cardPos = (sx + 50 * i + sw // 2, sy - checkSelect * 25 + sh // 2)
|
||||
cardSearchFrom = ci
|
||||
currCard = AllCardsNC[ci]
|
||||
cardInfo = (currCard, cardPos)
|
||||
hand_cards.append(cardInfo)
|
||||
outerBreak = True
|
||||
break
|
||||
if outerBreak:
|
||||
break
|
||||
if ci == len(AllCardsNC) - 1 and checkSelect == 0:
|
||||
checkSelect = 1
|
||||
ci = cardSearchFrom - 1
|
||||
ci += 1
|
||||
QtWidgets.QApplication.processEvents(QEventLoop.AllEvents, 10)
|
||||
print("GetCards Costs ", time.time()-st)
|
||||
return hand_cards, select_map
|
||||
|
||||
def LeftClick(self, pos):
|
||||
x, y = pos
|
||||
lParam = win32api.MAKELONG(x, y)
|
||||
win32gui.PostMessage(self.Handle, WM_MOUSEMOVE, MK_LBUTTON, lParam)
|
||||
win32gui.PostMessage(self.Handle, WM_LBUTTONDOWN, MK_LBUTTON, lParam)
|
||||
win32gui.PostMessage(self.Handle, WM_LBUTTONUP, MK_LBUTTON, lParam)
|
||||
|
||||
def SelectCards(self, cards):
|
||||
cards = [card for card in cards]
|
||||
tobeSelected = []
|
||||
tobeSelected.extend(cards)
|
||||
image, windowPos = self.Screenshot()
|
||||
handCardsInfo, states = self.GetCards(image)
|
||||
cardSelectMap = []
|
||||
for card in handCardsInfo:
|
||||
c = card[0]
|
||||
if c in tobeSelected:
|
||||
cardSelectMap.append(1)
|
||||
tobeSelected.remove(c)
|
||||
else:
|
||||
cardSelectMap.append(0)
|
||||
clickMap = []
|
||||
handcards = [c[0] for c in handCardsInfo]
|
||||
for i in range(0, len(cardSelectMap)):
|
||||
if cardSelectMap[i] == states[i]:
|
||||
clickMap.append(0)
|
||||
else:
|
||||
clickMap.append(1)
|
||||
while 1 in clickMap:
|
||||
for i in range(0, len(clickMap)):
|
||||
if clickMap[i] == 1:
|
||||
self.LeftClick(handCardsInfo[i][1])
|
||||
break
|
||||
time.sleep(0.1)
|
||||
if self.Interrupt:
|
||||
break
|
||||
image, _ = self.Screenshot()
|
||||
states = self.GetCardsState(image)
|
||||
clickMap = []
|
||||
for i in range(0, len(cardSelectMap)):
|
||||
if cardSelectMap[i] == states[i]:
|
||||
clickMap.append(0)
|
||||
else:
|
||||
clickMap.append(1)
|
||||
QtWidgets.QApplication.processEvents(QEventLoop.AllEvents, 10)
|
||||
|
||||
|
||||
# for file in os.listdir("pics"):
|
||||
# info = file.split(".")
|
||||
# if info[1] == "png":
|
||||
# tmpImage = Image.open("pics/" + file)
|
||||
# imgBGR = cv2.imread("pics/" + file)
|
||||
# Pics.update({info[0]: tmpImage})
|
||||
|
||||
if __name__ == "__main__":
|
||||
mp.freeze_support()
|
||||
class A:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
Pics = {}
|
||||
PicsCV = {}
|
||||
Handle = win32gui.FindWindow("Hlddz", None)
|
||||
form = A()
|
||||
form.MyHandCardsPos = (250, 764, 1141, 70) # 我的截图区域
|
||||
form.LPlayedCardsPos = (463, 355, 380, 250) # 左边截图区域
|
||||
form.RPlayedCardsPos = (946, 355, 380, 250) # 右边截图区域
|
||||
form.LandlordFlagPos = [(1281, 276, 110, 140), (267, 695, 110, 140), (424, 237, 110, 140)] # 地主标志截图区域(右-我-左)
|
||||
form.ThreeLandlordCardsPos = (753, 32, 287, 136) # 地主底牌截图区域,resize成349x168
|
||||
form.PassBtnPoss = (686, 659, 419, 100)
|
||||
GameHelper = GameHelper()
|
||||
# img, _ = GameHelper.Screenshot()
|
||||
img = Image.open(r"C:\Users\q9294\Desktop\cardselect.png")
|
||||
img2 = Image.open(r"pics/card_corner.png")
|
||||
img = img.resize((1796, 1047))
|
||||
# imgcv = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
|
||||
# st = time.time()
|
||||
# re = LocateOnImage(imgcv, GameHelper.PicsCV["card_edge"], region=None, confidence=0.999)
|
||||
# print(re)
|
||||
# print(time.time()-st)
|
||||
# st = time.time()
|
||||
# re = pyautogui.locate(needleImage=GameHelper.Pics["card_edge"], haystackImage=img, confidence=0.9)
|
||||
# print(re)
|
||||
# print(time.time()-st)
|
||||
# st = time.time()
|
||||
img, _ = GameHelper.Screenshot()
|
||||
cards, _= GameHelper.GetCards(img)
|
||||
# cards = "".join([i[0] for i in cards])
|
||||
print(cards)
|
||||
print(len(cards))
|
||||
# et = time.time()
|
||||
# print(et - st)
|
||||
# pos2 = pyautogui.locate(needleImage=Pics["card_edge"], haystackImage=img, confidence=0.9)
|
||||
# pos = FindImage(img, PicsCV["card_corner"], threshold=0.7)
|
||||
# print(pos)
|
||||
# print(pos2)
|
||||
# for p in pos:
|
||||
# img = DrawRectWithText(img, (p[0], p[1], 50, 50), "p1")
|
||||
# img = DrawRectWithText(img, (pos2[0], pos2[1], 50, 50), "p2")
|
||||
# img = DrawRectWithText(img, (sx+50*i, sy-checkSelect*25,sw,sh), c)
|
||||
# img = DrawRectWithText(img, form.LPlayedCardsPos, "LPlayed")
|
||||
# img = DrawRectWithText(img, form.RPlayedCardsPos, "RPlayed")
|
||||
# img = DrawRectWithText(img, form.MyHandCardsPos, "MyCard")
|
||||
# img = DrawRectWithText(img, form.ThreeLandlordCardsPos, "ThreeLLCPos")
|
||||
# img = DrawRectWithText(img, form.LandlordFlagPos[0], "RFlag")
|
||||
# img = DrawRectWithText(img, form.LandlordFlagPos[1], "MyFlag")
|
||||
# img = DrawRectWithText(img, form.LandlordFlagPos[2], "LFlag")
|
||||
# img = DrawRectWithText(img, form.PassBtnPoss, "Btns")
|
||||
# ShowImg(img)
|
||||
exit()
|
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,64 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Created by: Vincentzyx
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
import time
|
||||
|
||||
|
||||
def EnvToOnehot(cards):
|
||||
Env2IdxMap = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8,12:9,13:10,14:11,17:12,20:13,30:14}
|
||||
cards = [Env2IdxMap[i] for i in cards]
|
||||
Onehot = torch.zeros((4,15))
|
||||
for i in range(0, 15):
|
||||
Onehot[:cards.count(i),i] = 1
|
||||
return Onehot
|
||||
|
||||
def RealToOnehot(cards):
|
||||
RealCard2EnvCard = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4,
|
||||
'8': 5, '9': 6, 'T': 7, 'J': 8, 'Q': 9,
|
||||
'K': 10, 'A': 11, '2': 12, 'X': 13, 'D': 14}
|
||||
cards = [RealCard2EnvCard[c] for c in cards]
|
||||
Onehot = torch.zeros((4,15))
|
||||
for i in range(0, 15):
|
||||
Onehot[:cards.count(i),i] = 1
|
||||
return Onehot
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.fc1 = nn.Linear(60, 512)
|
||||
self.fc2 = nn.Linear(512, 512)
|
||||
self.fc3 = nn.Linear(512, 512)
|
||||
self.fc4 = nn.Linear(512, 512)
|
||||
self.fc5 = nn.Linear(512, 512)
|
||||
self.fc6 = nn.Linear(512, 1)
|
||||
self.dropout5 = nn.Dropout(0.5)
|
||||
self.dropout3 = nn.Dropout(0.3)
|
||||
self.dropout1 = nn.Dropout(0.1)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.fc1(input)
|
||||
x = torch.relu(self.dropout3(self.fc2(x)))
|
||||
x = torch.relu(self.dropout5(self.fc3(x)))
|
||||
x = torch.relu(self.dropout5(self.fc4(x)))
|
||||
x = torch.relu(self.dropout5(self.fc5(x)))
|
||||
x = self.fc6(x)
|
||||
return x
|
||||
|
||||
|
||||
net = Net()
|
||||
net.eval()
|
||||
if os.path.exists("landlord_weights.pkl"):
|
||||
net.load_state_dict(torch.load('landlord_weights.pkl'))
|
||||
else:
|
||||
print("landlord_weights.pkl not found")
|
||||
|
||||
def predict(cards):
|
||||
cards_onehot = torch.flatten(RealToOnehot(cards))
|
||||
y_predict = net(cards_onehot)
|
||||
return y_predict[0].item() * 100
|
|
@ -0,0 +1,217 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Form implementation generated from reading ui file 'MainWindow.ui'
|
||||
#
|
||||
# Created by: PyQt5 UI code generator 5.15.4
|
||||
#
|
||||
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
|
||||
# run again. Do not edit this file unless you know what you are doing.
|
||||
|
||||
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
|
||||
|
||||
class Ui_Form(object):
|
||||
def setupUi(self, Form):
|
||||
Form.setObjectName("Form")
|
||||
Form.resize(703, 421)
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Arial")
|
||||
font.setPointSize(9)
|
||||
font.setBold(True)
|
||||
font.setItalic(False)
|
||||
font.setWeight(75)
|
||||
Form.setFont(font)
|
||||
Form.setWindowOpacity(0.8)
|
||||
self.WinRate = QtWidgets.QLabel(Form)
|
||||
self.WinRate.setGeometry(QtCore.QRect(480, 150, 201, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(14)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.WinRate.setFont(font)
|
||||
self.WinRate.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.WinRate.setObjectName("WinRate")
|
||||
self.InitCard = QtWidgets.QPushButton(Form)
|
||||
self.InitCard.setGeometry(QtCore.QRect(80, 360, 121, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(14)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.InitCard.setFont(font)
|
||||
self.InitCard.setStyleSheet("")
|
||||
self.InitCard.setObjectName("InitCard")
|
||||
self.UserHandCards = QtWidgets.QLabel(Form)
|
||||
self.UserHandCards.setGeometry(QtCore.QRect(40, 160, 421, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(14)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.UserHandCards.setFont(font)
|
||||
self.UserHandCards.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.UserHandCards.setObjectName("UserHandCards")
|
||||
self.LPlayer = QtWidgets.QFrame(Form)
|
||||
self.LPlayer.setGeometry(QtCore.QRect(20, 80, 201, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(9)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.LPlayer.setFont(font)
|
||||
self.LPlayer.setFrameShape(QtWidgets.QFrame.StyledPanel)
|
||||
self.LPlayer.setFrameShadow(QtWidgets.QFrame.Raised)
|
||||
self.LPlayer.setObjectName("LPlayer")
|
||||
self.LPlayedCard = QtWidgets.QLabel(self.LPlayer)
|
||||
self.LPlayedCard.setGeometry(QtCore.QRect(0, 0, 201, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(14)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.LPlayedCard.setFont(font)
|
||||
self.LPlayedCard.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.LPlayedCard.setObjectName("LPlayedCard")
|
||||
self.RPlayer = QtWidgets.QFrame(Form)
|
||||
self.RPlayer.setGeometry(QtCore.QRect(250, 80, 201, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(16)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.RPlayer.setFont(font)
|
||||
self.RPlayer.setFrameShape(QtWidgets.QFrame.StyledPanel)
|
||||
self.RPlayer.setFrameShadow(QtWidgets.QFrame.Raised)
|
||||
self.RPlayer.setObjectName("RPlayer")
|
||||
self.RPlayedCard = QtWidgets.QLabel(self.RPlayer)
|
||||
self.RPlayedCard.setGeometry(QtCore.QRect(0, 0, 201, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(14)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.RPlayedCard.setFont(font)
|
||||
self.RPlayedCard.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.RPlayedCard.setObjectName("RPlayedCard")
|
||||
self.Player = QtWidgets.QFrame(Form)
|
||||
self.Player.setGeometry(QtCore.QRect(480, 80, 201, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(9)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.Player.setFont(font)
|
||||
self.Player.setFrameShape(QtWidgets.QFrame.StyledPanel)
|
||||
self.Player.setFrameShadow(QtWidgets.QFrame.Raised)
|
||||
self.Player.setObjectName("Player")
|
||||
self.PredictedCard = QtWidgets.QLabel(self.Player)
|
||||
self.PredictedCard.setGeometry(QtCore.QRect(0, 0, 201, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(14)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.PredictedCard.setFont(font)
|
||||
self.PredictedCard.setStyleSheet("")
|
||||
self.PredictedCard.setFrameShape(QtWidgets.QFrame.Panel)
|
||||
self.PredictedCard.setLineWidth(1)
|
||||
self.PredictedCard.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.PredictedCard.setObjectName("PredictedCard")
|
||||
self.ThreeLandlordCards = QtWidgets.QLabel(Form)
|
||||
self.ThreeLandlordCards.setGeometry(QtCore.QRect(270, 20, 161, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(16)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.ThreeLandlordCards.setFont(font)
|
||||
self.ThreeLandlordCards.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.ThreeLandlordCards.setObjectName("ThreeLandlordCards")
|
||||
self.Stop = QtWidgets.QPushButton(Form)
|
||||
self.Stop.setGeometry(QtCore.QRect(230, 360, 111, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(14)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.Stop.setFont(font)
|
||||
self.Stop.setStyleSheet("")
|
||||
self.Stop.setObjectName("Stop")
|
||||
self.SwitchMode = QtWidgets.QPushButton(Form)
|
||||
self.SwitchMode.setGeometry(QtCore.QRect(370, 360, 121, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(14)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.SwitchMode.setFont(font)
|
||||
self.SwitchMode.setStyleSheet("")
|
||||
self.SwitchMode.setObjectName("SwitchMode")
|
||||
self.AutoStart = QtWidgets.QPushButton(Form)
|
||||
self.AutoStart.setGeometry(QtCore.QRect(520, 360, 111, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(14)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.AutoStart.setFont(font)
|
||||
self.AutoStart.setStyleSheet("")
|
||||
self.AutoStart.setObjectName("AutoStart")
|
||||
self.BidWinrate = QtWidgets.QLabel(Form)
|
||||
self.BidWinrate.setGeometry(QtCore.QRect(50, 220, 241, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(12)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.BidWinrate.setFont(font)
|
||||
self.BidWinrate.setObjectName("BidWinrate")
|
||||
self.PreWinrate = QtWidgets.QLabel(Form)
|
||||
self.PreWinrate.setGeometry(QtCore.QRect(50, 270, 241, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("等线")
|
||||
font.setPointSize(12)
|
||||
font.setBold(False)
|
||||
font.setItalic(False)
|
||||
font.setWeight(50)
|
||||
self.PreWinrate.setFont(font)
|
||||
self.PreWinrate.setObjectName("PreWinrate")
|
||||
|
||||
self.retranslateUi(Form)
|
||||
self.InitCard.clicked.connect(Form.init_cards)
|
||||
self.Stop.clicked.connect(Form.stop)
|
||||
self.SwitchMode.clicked.connect(Form.switch_mode)
|
||||
self.AutoStart.clicked.connect(Form.beforeStart)
|
||||
QtCore.QMetaObject.connectSlotsByName(Form)
|
||||
|
||||
def retranslateUi(self, Form):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
Form.setWindowTitle(_translate("Form", "Hi"))
|
||||
self.WinRate.setText(_translate("Form", "评分"))
|
||||
self.InitCard.setText(_translate("Form", "开始"))
|
||||
self.UserHandCards.setText(_translate("Form", "手牌"))
|
||||
self.LPlayedCard.setText(_translate("Form", "上家出牌区域"))
|
||||
self.RPlayedCard.setText(_translate("Form", "下家出牌区域"))
|
||||
self.PredictedCard.setText(_translate("Form", "AI出牌区域"))
|
||||
self.ThreeLandlordCards.setText(_translate("Form", "地主牌"))
|
||||
self.Stop.setText(_translate("Form", "停止"))
|
||||
self.SwitchMode.setText(_translate("Form", "单局"))
|
||||
self.AutoStart.setText(_translate("Form", "自动开始"))
|
||||
self.BidWinrate.setText(_translate("Form", "叫牌预估胜率:"))
|
||||
self.PreWinrate.setText(_translate("Form", "局前预估胜率:"))
|
|
@ -0,0 +1,481 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<ui version="4.0">
|
||||
<class>Form</class>
|
||||
<widget class="QWidget" name="Form">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>703</width>
|
||||
<height>421</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>Arial</family>
|
||||
<pointsize>9</pointsize>
|
||||
<weight>75</weight>
|
||||
<italic>false</italic>
|
||||
<bold>true</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="windowTitle">
|
||||
<string>Hi</string>
|
||||
</property>
|
||||
<property name="windowOpacity">
|
||||
<double>0.800000000000000</double>
|
||||
</property>
|
||||
<widget class="QLabel" name="WinRate">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>480</x>
|
||||
<y>150</y>
|
||||
<width>201</width>
|
||||
<height>61</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>14</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>评分</string>
|
||||
</property>
|
||||
<property name="alignment">
|
||||
<set>Qt::AlignCenter</set>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="InitCard">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>80</x>
|
||||
<y>360</y>
|
||||
<width>121</width>
|
||||
<height>41</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>14</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="styleSheet">
|
||||
<string notr="true"/>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>开始</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QLabel" name="UserHandCards">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>40</x>
|
||||
<y>160</y>
|
||||
<width>421</width>
|
||||
<height>41</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>14</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>手牌</string>
|
||||
</property>
|
||||
<property name="alignment">
|
||||
<set>Qt::AlignCenter</set>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QFrame" name="LPlayer">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>20</x>
|
||||
<y>80</y>
|
||||
<width>201</width>
|
||||
<height>61</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>9</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="frameShape">
|
||||
<enum>QFrame::StyledPanel</enum>
|
||||
</property>
|
||||
<property name="frameShadow">
|
||||
<enum>QFrame::Raised</enum>
|
||||
</property>
|
||||
<widget class="QLabel" name="LPlayedCard">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>201</width>
|
||||
<height>61</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>14</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>上家出牌区域</string>
|
||||
</property>
|
||||
<property name="alignment">
|
||||
<set>Qt::AlignCenter</set>
|
||||
</property>
|
||||
</widget>
|
||||
</widget>
|
||||
<widget class="QFrame" name="RPlayer">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>250</x>
|
||||
<y>80</y>
|
||||
<width>201</width>
|
||||
<height>61</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>16</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="frameShape">
|
||||
<enum>QFrame::StyledPanel</enum>
|
||||
</property>
|
||||
<property name="frameShadow">
|
||||
<enum>QFrame::Raised</enum>
|
||||
</property>
|
||||
<widget class="QLabel" name="RPlayedCard">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>201</width>
|
||||
<height>61</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>14</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>下家出牌区域</string>
|
||||
</property>
|
||||
<property name="alignment">
|
||||
<set>Qt::AlignCenter</set>
|
||||
</property>
|
||||
</widget>
|
||||
</widget>
|
||||
<widget class="QFrame" name="Player">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>480</x>
|
||||
<y>80</y>
|
||||
<width>201</width>
|
||||
<height>61</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>9</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="frameShape">
|
||||
<enum>QFrame::StyledPanel</enum>
|
||||
</property>
|
||||
<property name="frameShadow">
|
||||
<enum>QFrame::Raised</enum>
|
||||
</property>
|
||||
<widget class="QLabel" name="PredictedCard">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>201</width>
|
||||
<height>61</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>14</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="styleSheet">
|
||||
<string notr="true"/>
|
||||
</property>
|
||||
<property name="frameShape">
|
||||
<enum>QFrame::Panel</enum>
|
||||
</property>
|
||||
<property name="lineWidth">
|
||||
<number>1</number>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>AI出牌区域</string>
|
||||
</property>
|
||||
<property name="alignment">
|
||||
<set>Qt::AlignCenter</set>
|
||||
</property>
|
||||
</widget>
|
||||
</widget>
|
||||
<widget class="QLabel" name="ThreeLandlordCards">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>270</x>
|
||||
<y>20</y>
|
||||
<width>161</width>
|
||||
<height>41</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>16</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>地主牌</string>
|
||||
</property>
|
||||
<property name="alignment">
|
||||
<set>Qt::AlignCenter</set>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="Stop">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>230</x>
|
||||
<y>360</y>
|
||||
<width>111</width>
|
||||
<height>41</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>14</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="styleSheet">
|
||||
<string notr="true"/>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>停止</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="SwitchMode">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>370</x>
|
||||
<y>360</y>
|
||||
<width>121</width>
|
||||
<height>41</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>14</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="styleSheet">
|
||||
<string notr="true"/>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>单局</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QPushButton" name="AutoStart">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>520</x>
|
||||
<y>360</y>
|
||||
<width>111</width>
|
||||
<height>41</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>14</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="styleSheet">
|
||||
<string notr="true"/>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>自动开始</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QLabel" name="BidWinrate">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>50</x>
|
||||
<y>220</y>
|
||||
<width>241</width>
|
||||
<height>41</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>12</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>叫牌预估胜率:</string>
|
||||
</property>
|
||||
</widget>
|
||||
<widget class="QLabel" name="PreWinrate">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>50</x>
|
||||
<y>270</y>
|
||||
<width>241</width>
|
||||
<height>41</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="font">
|
||||
<font>
|
||||
<family>等线</family>
|
||||
<pointsize>12</pointsize>
|
||||
<weight>50</weight>
|
||||
<italic>false</italic>
|
||||
<bold>false</bold>
|
||||
</font>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string>局前预估胜率:</string>
|
||||
</property>
|
||||
</widget>
|
||||
</widget>
|
||||
<resources/>
|
||||
<connections>
|
||||
<connection>
|
||||
<sender>InitCard</sender>
|
||||
<signal>clicked()</signal>
|
||||
<receiver>Form</receiver>
|
||||
<slot>init_cards()</slot>
|
||||
<hints>
|
||||
<hint type="sourcelabel">
|
||||
<x>200</x>
|
||||
<y>360</y>
|
||||
</hint>
|
||||
<hint type="destinationlabel">
|
||||
<x>250</x>
|
||||
<y>292</y>
|
||||
</hint>
|
||||
</hints>
|
||||
</connection>
|
||||
<connection>
|
||||
<sender>Stop</sender>
|
||||
<signal>clicked()</signal>
|
||||
<receiver>Form</receiver>
|
||||
<slot>stop()</slot>
|
||||
<hints>
|
||||
<hint type="sourcelabel">
|
||||
<x>230</x>
|
||||
<y>360</y>
|
||||
</hint>
|
||||
<hint type="destinationlabel">
|
||||
<x>233</x>
|
||||
<y>220</y>
|
||||
</hint>
|
||||
</hints>
|
||||
</connection>
|
||||
<connection>
|
||||
<sender>SwitchMode</sender>
|
||||
<signal>clicked()</signal>
|
||||
<receiver>Form</receiver>
|
||||
<slot>switch_mode()</slot>
|
||||
<hints>
|
||||
<hint type="sourcelabel">
|
||||
<x>472</x>
|
||||
<y>366</y>
|
||||
</hint>
|
||||
<hint type="destinationlabel">
|
||||
<x>480</x>
|
||||
<y>277</y>
|
||||
</hint>
|
||||
</hints>
|
||||
</connection>
|
||||
<connection>
|
||||
<sender>AutoStart</sender>
|
||||
<signal>clicked()</signal>
|
||||
<receiver>Form</receiver>
|
||||
<slot>beforeStart()</slot>
|
||||
<hints>
|
||||
<hint type="sourcelabel">
|
||||
<x>613</x>
|
||||
<y>372</y>
|
||||
</hint>
|
||||
<hint type="destinationlabel">
|
||||
<x>646</x>
|
||||
<y>291</y>
|
||||
</hint>
|
||||
</hints>
|
||||
</connection>
|
||||
</connections>
|
||||
<slots>
|
||||
<slot>init_cards()</slot>
|
||||
<slot>start()</slot>
|
||||
<slot>stop()</slot>
|
||||
<slot>switch_mode()</slot>
|
||||
<slot>beforeStart()</slot>
|
||||
</slots>
|
||||
</ui>
|
|
@ -0,0 +1,147 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Form implementation generated from reading ui file 'MainWindow.ui'
|
||||
#
|
||||
# Created by: PyQt5 UI code generator 5.13.0
|
||||
#
|
||||
# WARNING! All changes made in this file will be lost!
|
||||
|
||||
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
|
||||
|
||||
class Ui_Form(object):
|
||||
def setupUi(self, Form):
|
||||
Form.setObjectName("Form")
|
||||
Form.resize(440, 450)
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Arial")
|
||||
font.setPointSize(9)
|
||||
font.setBold(True)
|
||||
font.setItalic(False)
|
||||
font.setWeight(75)
|
||||
Form.setFont(font)
|
||||
self.WinRate = QtWidgets.QLabel(Form)
|
||||
self.WinRate.setGeometry(QtCore.QRect(240, 180, 171, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setPointSize(14)
|
||||
self.WinRate.setFont(font)
|
||||
self.WinRate.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.WinRate.setObjectName("WinRate")
|
||||
self.InitCard = QtWidgets.QPushButton(Form)
|
||||
self.InitCard.setGeometry(QtCore.QRect(60, 330, 121, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Arial")
|
||||
font.setPointSize(14)
|
||||
font.setBold(True)
|
||||
font.setWeight(75)
|
||||
self.InitCard.setFont(font)
|
||||
self.InitCard.setStyleSheet("")
|
||||
self.InitCard.setObjectName("InitCard")
|
||||
|
||||
self.SwitchMode = QtWidgets.QPushButton(Form)
|
||||
self.SwitchMode.setGeometry(QtCore.QRect(60, 380, 121, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Arial")
|
||||
font.setPointSize(14)
|
||||
font.setBold(True)
|
||||
font.setWeight(75)
|
||||
self.SwitchMode.setFont(font)
|
||||
self.SwitchMode.setStyleSheet("")
|
||||
self.SwitchMode.setObjectName("SwitchMode")
|
||||
|
||||
self.AutoStart = QtWidgets.QPushButton(Form)
|
||||
self.AutoStart.setGeometry(QtCore.QRect(260, 380, 111, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Arial")
|
||||
font.setPointSize(14)
|
||||
font.setBold(True)
|
||||
font.setWeight(75)
|
||||
self.AutoStart.setFont(font)
|
||||
self.AutoStart.setStyleSheet("")
|
||||
self.AutoStart.setObjectName("AutoStart")
|
||||
|
||||
self.UserHandCards = QtWidgets.QLabel(Form)
|
||||
self.UserHandCards.setGeometry(QtCore.QRect(10, 260, 421, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setPointSize(14)
|
||||
self.UserHandCards.setFont(font)
|
||||
self.UserHandCards.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.UserHandCards.setObjectName("UserHandCards")
|
||||
self.LPlayer = QtWidgets.QFrame(Form)
|
||||
self.LPlayer.setGeometry(QtCore.QRect(10, 80, 201, 61))
|
||||
self.LPlayer.setFrameShape(QtWidgets.QFrame.StyledPanel)
|
||||
self.LPlayer.setFrameShadow(QtWidgets.QFrame.Raised)
|
||||
self.LPlayer.setObjectName("LPlayer")
|
||||
self.LPlayedCard = QtWidgets.QLabel(self.LPlayer)
|
||||
self.LPlayedCard.setGeometry(QtCore.QRect(0, 0, 201, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setPointSize(14)
|
||||
self.LPlayedCard.setFont(font)
|
||||
self.LPlayedCard.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.LPlayedCard.setObjectName("LPlayedCard")
|
||||
self.RPlayer = QtWidgets.QFrame(Form)
|
||||
self.RPlayer.setGeometry(QtCore.QRect(230, 80, 201, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setPointSize(16)
|
||||
self.RPlayer.setFont(font)
|
||||
self.RPlayer.setFrameShape(QtWidgets.QFrame.StyledPanel)
|
||||
self.RPlayer.setFrameShadow(QtWidgets.QFrame.Raised)
|
||||
self.RPlayer.setObjectName("RPlayer")
|
||||
self.RPlayedCard = QtWidgets.QLabel(self.RPlayer)
|
||||
self.RPlayedCard.setGeometry(QtCore.QRect(0, 0, 201, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setPointSize(14)
|
||||
self.RPlayedCard.setFont(font)
|
||||
self.RPlayedCard.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.RPlayedCard.setObjectName("RPlayedCard")
|
||||
self.Player = QtWidgets.QFrame(Form)
|
||||
self.Player.setGeometry(QtCore.QRect(40, 180, 171, 61))
|
||||
self.Player.setFrameShape(QtWidgets.QFrame.StyledPanel)
|
||||
self.Player.setFrameShadow(QtWidgets.QFrame.Raised)
|
||||
self.Player.setObjectName("Player")
|
||||
self.PredictedCard = QtWidgets.QLabel(self.Player)
|
||||
self.PredictedCard.setGeometry(QtCore.QRect(0, 0, 171, 61))
|
||||
font = QtGui.QFont()
|
||||
font.setPointSize(14)
|
||||
self.PredictedCard.setFont(font)
|
||||
self.PredictedCard.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.PredictedCard.setObjectName("PredictedCard")
|
||||
self.ThreeLandlordCards = QtWidgets.QLabel(Form)
|
||||
self.ThreeLandlordCards.setGeometry(QtCore.QRect(140, 10, 161, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setPointSize(16)
|
||||
self.ThreeLandlordCards.setFont(font)
|
||||
self.ThreeLandlordCards.setAlignment(QtCore.Qt.AlignCenter)
|
||||
self.ThreeLandlordCards.setObjectName("ThreeLandlordCards")
|
||||
self.Stop = QtWidgets.QPushButton(Form)
|
||||
self.Stop.setGeometry(QtCore.QRect(260, 330, 111, 41))
|
||||
font = QtGui.QFont()
|
||||
font.setFamily("Arial")
|
||||
font.setPointSize(14)
|
||||
font.setBold(True)
|
||||
font.setWeight(75)
|
||||
self.Stop.setFont(font)
|
||||
self.Stop.setStyleSheet("")
|
||||
self.Stop.setObjectName("Stop")
|
||||
|
||||
self.retranslateUi(Form)
|
||||
self.InitCard.clicked.connect(Form.init_cards)
|
||||
self.Stop.clicked.connect(Form.stop)
|
||||
self.SwitchMode.clicked.connect(Form.switch_mode)
|
||||
self.AutoStart.clicked.connect(Form.beforeStart)
|
||||
QtCore.QMetaObject.connectSlotsByName(Form)
|
||||
|
||||
def retranslateUi(self, Form):
|
||||
_translate = QtCore.QCoreApplication.translate
|
||||
Form.setWindowTitle(_translate("Form", "Hi"))
|
||||
self.WinRate.setText(_translate("Form", ""))
|
||||
self.InitCard.setText(_translate("Form", "开始"))
|
||||
self.SwitchMode.setText(_translate("Form", "单局"))
|
||||
self.AutoStart.setText(_translate("Form", "自动开始"))
|
||||
self.UserHandCards.setText(_translate("Form", "手牌"))
|
||||
self.LPlayedCard.setText(_translate("Form", "上家出牌区域"))
|
||||
self.RPlayedCard.setText(_translate("Form", "下家出牌区域"))
|
||||
self.PredictedCard.setText(_translate("Form", "AI出牌区域"))
|
||||
self.ThreeLandlordCards.setText(_translate("Form", "三张底牌"))
|
||||
self.Stop.setText(_translate("Form", "停止"))
|
|
@ -0,0 +1,21 @@
|
|||
# DouZero_For_HLDDZ_FullAuto: 将DouZero用于欢乐斗地主自动化
|
||||
* 本项目基于[DouZero](https://github.com/kwai/DouZero) 和 [DouZero_For_Happy_DouDiZhu](https://github.com/tianqiraf/DouZero_For_HappyDouDiZhu)
|
||||
* 环境配置请移步项目DouZero
|
||||
* 模型默认为ADP,更换模型请修改main.py中的模型路径
|
||||
* 运行main.py即可
|
||||
* 在原 [DouZero_For_Happy_DouDiZhu](https://github.com/tianqiraf/DouZero_For_HappyDouDiZhu) 的基础上加入了自动出牌,基于手牌自动叫牌,加倍,同时修改截屏方式为窗口区域截屏,游戏原窗口遮挡不影响游戏进行。
|
||||
* **请勿把游戏界面最小化,否则无法使用**
|
||||
|
||||
## 说明
|
||||
* 欢乐斗地主使用窗口模式运行
|
||||
* **本项目仅供学习以及技术交流,请勿用于其它目的,否则后果自负。**
|
||||
|
||||
## 使用步骤
|
||||
1. 点击游戏中开始游戏后点击程序的`自动开始`
|
||||
|
||||
## 潜在Bug
|
||||
* 有较低几率把出牌识别为不出,从而卡在自己出牌阶段。
|
||||
|
||||
|
||||
## 鸣谢
|
||||
* 本项目基于[DouZero](https://github.com/kwai/DouZero) [DouZero_For_Happy_DouDiZhu](https://github.com/tianqiraf/DouZero_For_HappyDouDiZhu)
|
|
@ -0,0 +1,2 @@
|
|||
from .dmc import train
|
||||
from .arguments import parser
|
|
@ -0,0 +1,53 @@
|
|||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='DouZero: PyTorch DouDizhu AI')
|
||||
|
||||
# General Settings
|
||||
parser.add_argument('--xpid', default='douzero',
|
||||
help='Experiment id (default: douzero)')
|
||||
parser.add_argument('--save_interval', default=30, type=int,
|
||||
help='Time interval (in minutes) at which to save the model')
|
||||
parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp'],
|
||||
help='Use ADP or WP as reward (default: ADP)')
|
||||
|
||||
# Training settings
|
||||
parser.add_argument('--gpu_devices', default='0', type=str,
|
||||
help='Which GPUs to be used for training')
|
||||
parser.add_argument('--num_actor_devices', default=1, type=int,
|
||||
help='The number of devices used for simulation')
|
||||
parser.add_argument('--num_actors', default=5, type=int,
|
||||
help='The number of actors for each simulation device')
|
||||
parser.add_argument('--training_device', default=0, type=int,
|
||||
help='The index of the GPU used for training models')
|
||||
parser.add_argument('--load_model', action='store_true',
|
||||
help='Load an existing model')
|
||||
parser.add_argument('--disable_checkpoint', action='store_true',
|
||||
help='Disable saving checkpoint')
|
||||
parser.add_argument('--savedir', default='douzero_checkpoints',
|
||||
help='Root dir where experiment data will be saved')
|
||||
|
||||
# Hyperparameters
|
||||
parser.add_argument('--total_frames', default=100000000000, type=int,
|
||||
help='Total environment frames to train for')
|
||||
parser.add_argument('--exp_epsilon', default=0.01, type=float,
|
||||
help='The probability for exploration')
|
||||
parser.add_argument('--batch_size', default=32, type=int,
|
||||
help='Learner batch size')
|
||||
parser.add_argument('--unroll_length', default=100, type=int,
|
||||
help='The unroll length (time dimension)')
|
||||
parser.add_argument('--num_buffers', default=50, type=int,
|
||||
help='Number of shared-memory buffers')
|
||||
parser.add_argument('--num_threads', default=4, type=int,
|
||||
help='Number learner threads')
|
||||
parser.add_argument('--max_grad_norm', default=40., type=float,
|
||||
help='Max norm of gradients')
|
||||
|
||||
# Optimizer settings
|
||||
parser.add_argument('--learning_rate', default=0.0001, type=float,
|
||||
help='Learning rate')
|
||||
parser.add_argument('--alpha', default=0.99, type=float,
|
||||
help='RMSProp smoothing constant')
|
||||
parser.add_argument('--momentum', default=0, type=float,
|
||||
help='RMSProp momentum')
|
||||
parser.add_argument('--epsilon', default=1e-5, type=float,
|
||||
help='RMSProp epsilon')
|
|
@ -0,0 +1,231 @@
|
|||
import os
|
||||
import threading
|
||||
import time
|
||||
import timeit
|
||||
import pprint
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import multiprocessing as mp
|
||||
from torch import nn
|
||||
|
||||
from .file_writer import FileWriter
|
||||
from .models import Model
|
||||
from .utils import get_batch, log, create_env, create_buffers, create_optimizers, act
|
||||
|
||||
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_down']}
|
||||
|
||||
def compute_loss(logits, targets):
|
||||
loss = ((logits.squeeze(-1) - targets)**2).mean()
|
||||
return loss
|
||||
|
||||
def learn(position,
|
||||
actor_models,
|
||||
model,
|
||||
batch,
|
||||
optimizer,
|
||||
flags,
|
||||
lock):
|
||||
"""Performs a learning (optimization) step."""
|
||||
device = torch.device('cuda:'+str(flags.training_device))
|
||||
obs_x_no_action = batch['obs_x_no_action'].to(device)
|
||||
obs_action = batch['obs_action'].to(device)
|
||||
obs_x = torch.cat((obs_x_no_action, obs_action), dim=2).float()
|
||||
obs_x = torch.flatten(obs_x, 0, 1)
|
||||
obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
|
||||
target = torch.flatten(batch['target'].to(device), 0, 1)
|
||||
episode_returns = batch['episode_return'][batch['done']]
|
||||
mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device))
|
||||
|
||||
with lock:
|
||||
learner_outputs = model(obs_z, obs_x, return_value=True)
|
||||
loss = compute_loss(learner_outputs['values'], target)
|
||||
stats = {
|
||||
'mean_episode_return_'+position: torch.mean(torch.stack([_r for _r in mean_episode_return_buf[position]])).item(),
|
||||
'loss_'+position: loss.item(),
|
||||
}
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
|
||||
optimizer.step()
|
||||
|
||||
for actor_model in actor_models:
|
||||
actor_model.get_model(position).load_state_dict(model.state_dict())
|
||||
return stats
|
||||
|
||||
def train(flags):
|
||||
"""
|
||||
This is the main funtion for training. It will first
|
||||
initilize everything, such as buffers, optimizers, etc.
|
||||
Then it will start subprocesses as actors. Then, it will call
|
||||
learning function with multiple threads.
|
||||
"""
|
||||
plogger = FileWriter(
|
||||
xpid=flags.xpid,
|
||||
xp_args=flags.__dict__,
|
||||
rootdir=flags.savedir,
|
||||
)
|
||||
checkpointpath = os.path.expandvars(
|
||||
os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))
|
||||
|
||||
T = flags.unroll_length
|
||||
B = flags.batch_size
|
||||
|
||||
# Initialize actor models
|
||||
models = []
|
||||
assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
|
||||
for device in range(flags.num_actor_devices):
|
||||
model = Model(device=device)
|
||||
model.share_memory()
|
||||
model.eval()
|
||||
models.append(model)
|
||||
|
||||
# Initialize buffers
|
||||
buffers = create_buffers(flags)
|
||||
|
||||
# Initialize queues
|
||||
actor_processes = []
|
||||
ctx = mp.get_context('spawn')
|
||||
free_queue = []
|
||||
full_queue = []
|
||||
for device in range(flags.num_actor_devices):
|
||||
_free_queue = {'landlord': ctx.SimpleQueue(), 'landlord_up': ctx.SimpleQueue(), 'landlord_down': ctx.SimpleQueue()}
|
||||
_full_queue = {'landlord': ctx.SimpleQueue(), 'landlord_up': ctx.SimpleQueue(), 'landlord_down': ctx.SimpleQueue()}
|
||||
free_queue.append(_free_queue)
|
||||
full_queue.append(_full_queue)
|
||||
|
||||
# Learner model for training
|
||||
learner_model = Model(device=flags.training_device)
|
||||
|
||||
# Create optimizers
|
||||
optimizers = create_optimizers(flags, learner_model)
|
||||
|
||||
# Stat Keys
|
||||
stat_keys = [
|
||||
'mean_episode_return_landlord',
|
||||
'loss_landlord',
|
||||
'mean_episode_return_landlord_up',
|
||||
'loss_landlord_up',
|
||||
'mean_episode_return_landlord_down',
|
||||
'loss_landlord_down',
|
||||
]
|
||||
frames, stats = 0, {k: 0 for k in stat_keys}
|
||||
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_down':0}
|
||||
|
||||
# Load models if any
|
||||
if flags.load_model and os.path.exists(checkpointpath):
|
||||
checkpoint_states = torch.load(
|
||||
checkpointpath, map_location="cuda:"+str(flags.training_device)
|
||||
)
|
||||
for k in ['landlord', 'landlord_up', 'landlord_down']:
|
||||
learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
|
||||
optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
|
||||
for device in range(flags.num_actor_devices):
|
||||
models[device].get_model(k).load_state_dict(learner_model.get_model(k).state_dict())
|
||||
stats = checkpoint_states["stats"]
|
||||
frames = checkpoint_states["frames"]
|
||||
position_frames = checkpoint_states["position_frames"]
|
||||
log.info(f"Resuming preempted job, current stats:\n{stats}")
|
||||
|
||||
# Starting actor processes
|
||||
for device in range(flags.num_actor_devices):
|
||||
num_actors = flags.num_actors
|
||||
for i in range(flags.num_actors):
|
||||
actor = ctx.Process(
|
||||
target=act,
|
||||
args=(i, device, free_queue[device], full_queue[device], models[device], buffers[device], flags))
|
||||
actor.start()
|
||||
actor_processes.append(actor)
|
||||
|
||||
def batch_and_learn(i, device, position, local_lock, position_lock, lock=threading.Lock()):
|
||||
"""Thread target for the learning process."""
|
||||
nonlocal frames, position_frames, stats
|
||||
while frames < flags.total_frames:
|
||||
batch = get_batch(free_queue[device][position], full_queue[device][position], buffers[device][position], flags, local_lock)
|
||||
_stats = learn(position, models, learner_model.get_model(position), batch,
|
||||
optimizers[position], flags, position_lock)
|
||||
|
||||
with lock:
|
||||
for k in _stats:
|
||||
stats[k] = _stats[k]
|
||||
to_log = dict(frames=frames)
|
||||
to_log.update({k: stats[k] for k in stat_keys})
|
||||
plogger.log(to_log)
|
||||
frames += T * B
|
||||
position_frames[position] += T * B
|
||||
|
||||
for device in range(flags.num_actor_devices):
|
||||
for m in range(flags.num_buffers):
|
||||
free_queue[device]['landlord'].put(m)
|
||||
free_queue[device]['landlord_up'].put(m)
|
||||
free_queue[device]['landlord_down'].put(m)
|
||||
|
||||
threads = []
|
||||
locks = [{'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_down': threading.Lock()} for _ in range(flags.num_actor_devices)]
|
||||
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||
|
||||
for device in range(flags.num_actor_devices):
|
||||
for i in range(flags.num_threads):
|
||||
for position in ['landlord', 'landlord_up', 'landlord_down']:
|
||||
thread = threading.Thread(
|
||||
target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,device,position,locks[device][position],position_locks[position]))
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
def checkpoint(frames):
|
||||
if flags.disable_checkpoint:
|
||||
return
|
||||
log.info('Saving checkpoint to %s', checkpointpath)
|
||||
_models = learner_model.get_models()
|
||||
torch.save({
|
||||
'model_state_dict': {k: _models[k].state_dict() for k in _models},
|
||||
'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
|
||||
"stats": stats,
|
||||
'flags': vars(flags),
|
||||
'frames': frames,
|
||||
'position_frames': position_frames
|
||||
}, checkpointpath)
|
||||
|
||||
# Save the weights for evaluation purpose
|
||||
for position in ['landlord', 'landlord_up', 'landlord_down']:
|
||||
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
||||
'%s/%s/%s' % (flags.savedir, flags.xpid, position+'_weights_'+str(frames)+'.ckpt')))
|
||||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||
|
||||
timer = timeit.default_timer
|
||||
try:
|
||||
last_checkpoint_time = timer() - flags.save_interval * 60
|
||||
while frames < flags.total_frames:
|
||||
start_frames = frames
|
||||
position_start_frames = {k: position_frames[k] for k in position_frames}
|
||||
start_time = timer()
|
||||
time.sleep(5)
|
||||
|
||||
if timer() - last_checkpoint_time > flags.save_interval * 60:
|
||||
checkpoint(frames)
|
||||
last_checkpoint_time = timer()
|
||||
|
||||
end_time = timer()
|
||||
fps = (frames - start_frames) / (end_time - start_time)
|
||||
position_fps = {k:(position_frames[k]-position_start_frames[k])/(end_time-start_time) for k in position_frames}
|
||||
log.info('After %i (L:%i U:%i D:%i) frames: @ %.1f fps (L:%.1f U:%.1f D:%.1f) Stats:\n%s',
|
||||
frames,
|
||||
position_frames['landlord'],
|
||||
position_frames['landlord_up'],
|
||||
position_frames['landlord_down'],
|
||||
fps,
|
||||
position_fps['landlord'],
|
||||
position_fps['landlord_up'],
|
||||
position_fps['landlord_down'],
|
||||
pprint.pformat(stats))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
return
|
||||
else:
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
log.info('Learning finished after %d frames.', frames)
|
||||
|
||||
checkpoint(frames)
|
||||
plogger.close()
|
|
@ -0,0 +1,69 @@
|
|||
"""
|
||||
Here, we wrap the original environment to make it easier
|
||||
to use. When a game is finished, instead of mannualy reseting
|
||||
the environment, we do it automatically.
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
def _format_observation(obs, device):
|
||||
"""
|
||||
A utility function to process observations and
|
||||
move them to CUDA.
|
||||
"""
|
||||
position = obs['position']
|
||||
device = torch.device('cuda:'+str(device))
|
||||
x_batch = torch.from_numpy(obs['x_batch']).to(device)
|
||||
z_batch = torch.from_numpy(obs['z_batch']).to(device)
|
||||
x_no_action = torch.from_numpy(obs['x_no_action'])
|
||||
z = torch.from_numpy(obs['z'])
|
||||
obs = {'x_batch': x_batch,
|
||||
'z_batch': z_batch,
|
||||
'legal_actions': obs['legal_actions'],
|
||||
}
|
||||
return position, obs, x_no_action, z
|
||||
|
||||
class Environment:
|
||||
def __init__(self, env, device):
|
||||
""" Initialzie this environment wrapper
|
||||
"""
|
||||
self.env = env
|
||||
self.device = device
|
||||
self.episode_return = None
|
||||
|
||||
def initial(self):
|
||||
initial_position, initial_obs, x_no_action, z = _format_observation(self.env.reset(), self.device)
|
||||
initial_reward = torch.zeros(1, 1)
|
||||
self.episode_return = torch.zeros(1, 1)
|
||||
initial_done = torch.ones(1, 1, dtype=torch.bool)
|
||||
|
||||
return initial_position, initial_obs, dict(
|
||||
done=initial_done,
|
||||
episode_return=self.episode_return,
|
||||
obs_x_no_action=x_no_action,
|
||||
obs_z=z,
|
||||
)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, _ = self.env.step(action)
|
||||
|
||||
self.episode_return += reward
|
||||
episode_return = self.episode_return
|
||||
|
||||
if done:
|
||||
obs = self.env.reset()
|
||||
self.episode_return = torch.zeros(1, 1)
|
||||
|
||||
position, obs, x_no_action, z = _format_observation(obs, self.device)
|
||||
reward = torch.tensor(reward).view(1, 1)
|
||||
done = torch.tensor(done).view(1, 1)
|
||||
|
||||
return position, obs, dict(
|
||||
done=done,
|
||||
episode_return=episode_return,
|
||||
obs_x_no_action=x_no_action,
|
||||
obs_z=z,
|
||||
)
|
||||
|
||||
def close(self):
|
||||
self.env.close()
|
|
@ -0,0 +1,187 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
# import git
|
||||
|
||||
|
||||
# def gather_metadata() -> Dict:
|
||||
# date_start = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
|
||||
# # gathering git metadata
|
||||
# try:
|
||||
# repo = git.Repo(search_parent_directories=True)
|
||||
# git_sha = repo.commit().hexsha
|
||||
# git_data = dict(
|
||||
# commit=git_sha,
|
||||
# branch=repo.active_branch.name,
|
||||
# is_dirty=repo.is_dirty(),
|
||||
# path=repo.git_dir,
|
||||
# )
|
||||
# except git.InvalidGitRepositoryError:
|
||||
# git_data = None
|
||||
# # gathering slurm metadata
|
||||
# if 'SLURM_JOB_ID' in os.environ:
|
||||
# slurm_env_keys = [k for k in os.environ if k.startswith('SLURM')]
|
||||
# slurm_data = {}
|
||||
# for k in slurm_env_keys:
|
||||
# d_key = k.replace('SLURM_', '').replace('SLURMD_', '').lower()
|
||||
# slurm_data[d_key] = os.environ[k]
|
||||
# else:
|
||||
# slurm_data = None
|
||||
# return dict(
|
||||
# date_start=date_start,
|
||||
# date_end=None,
|
||||
# successful=False,
|
||||
# git=git_data,
|
||||
# slurm=slurm_data,
|
||||
# env=os.environ.copy(),
|
||||
# )
|
||||
|
||||
|
||||
class FileWriter:
|
||||
def __init__(self,
|
||||
xpid: str = None,
|
||||
xp_args: dict = None,
|
||||
rootdir: str = '~/palaas'):
|
||||
if not xpid:
|
||||
# make unique id
|
||||
xpid = '{proc}_{unixtime}'.format(
|
||||
proc=os.getpid(), unixtime=int(time.time()))
|
||||
self.xpid = xpid
|
||||
self._tick = 0
|
||||
|
||||
# metadata gathering
|
||||
if xp_args is None:
|
||||
xp_args = {}
|
||||
self.metadata = gather_metadata()
|
||||
# we need to copy the args, otherwise when we close the file writer
|
||||
# (and rewrite the args) we might have non-serializable objects (or
|
||||
# other nasty stuff).
|
||||
self.metadata['args'] = copy.deepcopy(xp_args)
|
||||
self.metadata['xpid'] = self.xpid
|
||||
|
||||
formatter = logging.Formatter('%(message)s')
|
||||
self._logger = logging.getLogger('palaas/out')
|
||||
|
||||
# to stdout handler
|
||||
shandle = logging.StreamHandler()
|
||||
shandle.setFormatter(formatter)
|
||||
self._logger.addHandler(shandle)
|
||||
self._logger.setLevel(logging.INFO)
|
||||
|
||||
rootdir = os.path.expandvars(os.path.expanduser(rootdir))
|
||||
# to file handler
|
||||
self.basepath = os.path.join(rootdir, self.xpid)
|
||||
|
||||
if not os.path.exists(self.basepath):
|
||||
self._logger.info('Creating log directory: %s', self.basepath)
|
||||
os.makedirs(self.basepath, exist_ok=True)
|
||||
else:
|
||||
self._logger.info('Found log directory: %s', self.basepath)
|
||||
|
||||
# NOTE: remove latest because it creates errors when running on slurm
|
||||
# multiple jobs trying to write to latest but cannot find it
|
||||
# Add 'latest' as symlink unless it exists and is no symlink.
|
||||
# symlink = os.path.join(rootdir, 'latest')
|
||||
# if os.path.islink(symlink):
|
||||
# os.remove(symlink)
|
||||
# if not os.path.exists(symlink):
|
||||
# os.symlink(self.basepath, symlink)
|
||||
# self._logger.info('Symlinked log directory: %s', symlink)
|
||||
|
||||
self.paths = dict(
|
||||
msg='{base}/out.log'.format(base=self.basepath),
|
||||
logs='{base}/logs.csv'.format(base=self.basepath),
|
||||
fields='{base}/fields.csv'.format(base=self.basepath),
|
||||
meta='{base}/meta.json'.format(base=self.basepath),
|
||||
)
|
||||
|
||||
self._logger.info('Saving arguments to %s', self.paths['meta'])
|
||||
if os.path.exists(self.paths['meta']):
|
||||
self._logger.warning('Path to meta file already exists. '
|
||||
'Not overriding meta.')
|
||||
else:
|
||||
self._save_metadata()
|
||||
|
||||
self._logger.info('Saving messages to %s', self.paths['msg'])
|
||||
if os.path.exists(self.paths['msg']):
|
||||
self._logger.warning('Path to message file already exists. '
|
||||
'New data will be appended.')
|
||||
|
||||
fhandle = logging.FileHandler(self.paths['msg'])
|
||||
fhandle.setFormatter(formatter)
|
||||
self._logger.addHandler(fhandle)
|
||||
|
||||
self._logger.info('Saving logs data to %s', self.paths['logs'])
|
||||
self._logger.info('Saving logs\' fields to %s', self.paths['fields'])
|
||||
if os.path.exists(self.paths['logs']):
|
||||
self._logger.warning('Path to log file already exists. '
|
||||
'New data will be appended.')
|
||||
with open(self.paths['fields'], 'r') as csvfile:
|
||||
reader = csv.reader(csvfile)
|
||||
self.fieldnames = list(reader)[0]
|
||||
else:
|
||||
self.fieldnames = ['_tick', '_time']
|
||||
|
||||
def log(self, to_log: Dict, tick: int = None,
|
||||
verbose: bool = False) -> None:
|
||||
if tick is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
to_log['_tick'] = self._tick
|
||||
self._tick += 1
|
||||
to_log['_time'] = time.time()
|
||||
|
||||
old_len = len(self.fieldnames)
|
||||
for k in to_log:
|
||||
if k not in self.fieldnames:
|
||||
self.fieldnames.append(k)
|
||||
if old_len != len(self.fieldnames):
|
||||
with open(self.paths['fields'], 'w') as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow(self.fieldnames)
|
||||
self._logger.info('Updated log fields: %s', self.fieldnames)
|
||||
|
||||
if to_log['_tick'] == 0:
|
||||
# print("\ncreating logs file ")
|
||||
with open(self.paths['logs'], 'a') as f:
|
||||
f.write('# %s\n' % ','.join(self.fieldnames))
|
||||
|
||||
if verbose:
|
||||
self._logger.info('LOG | %s', ', '.join(
|
||||
['{}: {}'.format(k, to_log[k]) for k in sorted(to_log)]))
|
||||
|
||||
with open(self.paths['logs'], 'a') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=self.fieldnames)
|
||||
writer.writerow(to_log)
|
||||
# print("\nadded to log file")
|
||||
|
||||
def close(self, successful: bool = True) -> None:
|
||||
self.metadata['date_end'] = datetime.datetime.now().strftime(
|
||||
'%Y-%m-%d %H:%M:%S.%f')
|
||||
self.metadata['successful'] = successful
|
||||
self._save_metadata()
|
||||
|
||||
def _save_metadata(self) -> None:
|
||||
with open(self.paths['meta'], 'w') as jsonfile:
|
||||
json.dump(self.metadata, jsonfile, indent=4, sort_keys=True)
|
|
@ -0,0 +1,119 @@
|
|||
"""
|
||||
This file includes the torch models. We wrap the three
|
||||
models into one class for convenience.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class LandlordLstmModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(162, 128, batch_first=True)
|
||||
self.dense1 = nn.Linear(373 + 128, 512)
|
||||
self.dense2 = nn.Linear(512, 512)
|
||||
self.dense3 = nn.Linear(512, 512)
|
||||
self.dense4 = nn.Linear(512, 512)
|
||||
self.dense5 = nn.Linear(512, 512)
|
||||
self.dense6 = nn.Linear(512, 1)
|
||||
|
||||
def forward(self, z, x, return_value=False, flags=None):
|
||||
lstm_out, (h_n, _) = self.lstm(z)
|
||||
lstm_out = lstm_out[:,-1,:]
|
||||
x = torch.cat([lstm_out,x], dim=-1)
|
||||
x = self.dense1(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense2(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense3(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense4(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense5(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense6(x)
|
||||
if return_value:
|
||||
return dict(values=x)
|
||||
else:
|
||||
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
|
||||
action = torch.randint(x.shape[0], (1,))[0]
|
||||
else:
|
||||
action = torch.argmax(x,dim=0)[0]
|
||||
return dict(action=action)
|
||||
|
||||
class FarmerLstmModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(162, 128, batch_first=True)
|
||||
self.dense1 = nn.Linear(484 + 128, 512)
|
||||
self.dense2 = nn.Linear(512, 512)
|
||||
self.dense3 = nn.Linear(512, 512)
|
||||
self.dense4 = nn.Linear(512, 512)
|
||||
self.dense5 = nn.Linear(512, 512)
|
||||
self.dense6 = nn.Linear(512, 1)
|
||||
|
||||
def forward(self, z, x, return_value=False, flags=None):
|
||||
lstm_out, (h_n, _) = self.lstm(z)
|
||||
lstm_out = lstm_out[:,-1,:]
|
||||
x = torch.cat([lstm_out,x], dim=-1)
|
||||
x = self.dense1(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense2(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense3(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense4(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense5(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense6(x)
|
||||
if return_value:
|
||||
return dict(values=x)
|
||||
else:
|
||||
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
|
||||
action = torch.randint(x.shape[0], (1,))[0]
|
||||
else:
|
||||
action = torch.argmax(x,dim=0)[0]
|
||||
return dict(action=action)
|
||||
|
||||
# Model dict is only used in evaluation but not training
|
||||
model_dict = {}
|
||||
model_dict['landlord'] = LandlordLstmModel
|
||||
model_dict['landlord_up'] = FarmerLstmModel
|
||||
model_dict['landlord_down'] = FarmerLstmModel
|
||||
|
||||
class Model:
|
||||
"""
|
||||
The wrapper for the three models. We also wrap several
|
||||
interfaces such as share_memory, eval, etc.
|
||||
"""
|
||||
def __init__(self, device=0):
|
||||
self.models = {}
|
||||
self.models['landlord'] = LandlordLstmModel().to(torch.device('cuda:'+str(device)))
|
||||
self.models['landlord_up'] = FarmerLstmModel().to(torch.device('cuda:'+str(device)))
|
||||
self.models['landlord_down'] = FarmerLstmModel().to(torch.device('cuda:'+str(device)))
|
||||
|
||||
def forward(self, position, z, x, training=False, flags=None):
|
||||
model = self.models[position]
|
||||
return model.forward(z, x, training, flags)
|
||||
|
||||
def share_memory(self):
|
||||
self.models['landlord'].share_memory()
|
||||
self.models['landlord_up'].share_memory()
|
||||
self.models['landlord_down'].share_memory()
|
||||
|
||||
def eval(self):
|
||||
self.models['landlord'].eval()
|
||||
self.models['landlord_up'].eval()
|
||||
self.models['landlord_down'].eval()
|
||||
|
||||
def parameters(self, position):
|
||||
return self.models[position].parameters()
|
||||
|
||||
def get_model(self, position):
|
||||
return self.models[position]
|
||||
|
||||
def get_models(self):
|
||||
return self.models
|
|
@ -0,0 +1,205 @@
|
|||
import os
|
||||
import typing
|
||||
import logging
|
||||
import traceback
|
||||
import numpy as np
|
||||
from collections import Counter
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch import multiprocessing as mp
|
||||
|
||||
from .env_utils import Environment
|
||||
from douzero.env import Env
|
||||
|
||||
Card2Column = {3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7,
|
||||
11: 8, 12: 9, 13: 10, 14: 11, 17: 12}
|
||||
|
||||
NumOnes2Array = {0: np.array([0, 0, 0, 0]),
|
||||
1: np.array([1, 0, 0, 0]),
|
||||
2: np.array([1, 1, 0, 0]),
|
||||
3: np.array([1, 1, 1, 0]),
|
||||
4: np.array([1, 1, 1, 1])}
|
||||
|
||||
shandle = logging.StreamHandler()
|
||||
shandle.setFormatter(
|
||||
logging.Formatter(
|
||||
'[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
|
||||
'%(message)s'))
|
||||
log = logging.getLogger('doudzero')
|
||||
log.propagate = False
|
||||
log.addHandler(shandle)
|
||||
log.setLevel(logging.INFO)
|
||||
|
||||
# Buffers are used to transfer data between actor processes
|
||||
# and learner processes. They are shared tensors in GPU
|
||||
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
||||
|
||||
def create_env(flags):
|
||||
return Env(flags.objective)
|
||||
|
||||
def get_batch(free_queue,
|
||||
full_queue,
|
||||
buffers,
|
||||
flags,
|
||||
lock):
|
||||
"""
|
||||
This function will sample a batch from the buffers based
|
||||
on the indices received from the full queue. It will also
|
||||
free the indices by sending it to full_queue.
|
||||
"""
|
||||
with lock:
|
||||
indices = [full_queue.get() for _ in range(flags.batch_size)]
|
||||
batch = {
|
||||
key: torch.stack([buffers[key][m] for m in indices], dim=1)
|
||||
for key in buffers
|
||||
}
|
||||
for m in indices:
|
||||
free_queue.put(m)
|
||||
return batch
|
||||
|
||||
def create_optimizers(flags, learner_model):
|
||||
"""
|
||||
Create three optimizers for the three positions
|
||||
"""
|
||||
positions = ['landlord', 'landlord_up', 'landlord_down']
|
||||
optimizers = {}
|
||||
for position in positions:
|
||||
optimizer = torch.optim.RMSprop(
|
||||
learner_model.parameters(position),
|
||||
lr=flags.learning_rate,
|
||||
momentum=flags.momentum,
|
||||
eps=flags.epsilon,
|
||||
alpha=flags.alpha)
|
||||
optimizers[position] = optimizer
|
||||
return optimizers
|
||||
|
||||
def create_buffers(flags):
|
||||
"""
|
||||
We create buffers for different positions as well as
|
||||
for different devices (i.e., GPU). That is, each device
|
||||
will have three buffers for the three positions.
|
||||
"""
|
||||
T = flags.unroll_length
|
||||
positions = ['landlord', 'landlord_up', 'landlord_down']
|
||||
buffers = []
|
||||
for device in range(torch.cuda.device_count()):
|
||||
buffers.append({})
|
||||
for position in positions:
|
||||
x_dim = 319 if position == 'landlord' else 430
|
||||
specs = dict(
|
||||
done=dict(size=(T,), dtype=torch.bool),
|
||||
episode_return=dict(size=(T,), dtype=torch.float32),
|
||||
target=dict(size=(T,), dtype=torch.float32),
|
||||
obs_x_no_action=dict(size=(T, x_dim), dtype=torch.int8),
|
||||
obs_action=dict(size=(T, 54), dtype=torch.int8),
|
||||
obs_z=dict(size=(T, 5, 162), dtype=torch.int8),
|
||||
)
|
||||
_buffers: Buffers = {key: [] for key in specs}
|
||||
for _ in range(flags.num_buffers):
|
||||
for key in _buffers:
|
||||
_buffer = torch.empty(**specs[key]).to(torch.device('cuda:'+str(device))).share_memory_()
|
||||
_buffers[key].append(_buffer)
|
||||
buffers[device][position] = _buffers
|
||||
return buffers
|
||||
|
||||
def act(i, device, free_queue, full_queue, model, buffers, flags):
|
||||
"""
|
||||
This function will run forever until we stop it. It will generate
|
||||
data from the environment and send the data to buffer. It uses
|
||||
a free queue and full queue to syncup with the main process.
|
||||
"""
|
||||
positions = ['landlord', 'landlord_up', 'landlord_down']
|
||||
try:
|
||||
T = flags.unroll_length
|
||||
log.info('Device %i Actor %i started.', device, i)
|
||||
|
||||
env = create_env(flags)
|
||||
|
||||
env = Environment(env, device)
|
||||
|
||||
done_buf = {p: [] for p in positions}
|
||||
episode_return_buf = {p: [] for p in positions}
|
||||
target_buf = {p: [] for p in positions}
|
||||
obs_x_no_action_buf = {p: [] for p in positions}
|
||||
obs_action_buf = {p: [] for p in positions}
|
||||
obs_z_buf = {p: [] for p in positions}
|
||||
size = {p: 0 for p in positions}
|
||||
|
||||
position, obs, env_output = env.initial()
|
||||
|
||||
while True:
|
||||
while True:
|
||||
obs_x_no_action_buf[position].append(env_output['obs_x_no_action'])
|
||||
obs_z_buf[position].append(env_output['obs_z'])
|
||||
with torch.no_grad():
|
||||
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
|
||||
_action_idx = int(agent_output['action'].cpu().detach().numpy())
|
||||
action = obs['legal_actions'][_action_idx]
|
||||
obs_action_buf[position].append(_cards2tensor(action))
|
||||
position, obs, env_output = env.step(action)
|
||||
size[position] += 1
|
||||
if env_output['done']:
|
||||
for p in positions:
|
||||
diff = size[p] - len(target_buf[p])
|
||||
if diff > 0:
|
||||
done_buf[p].extend([False for _ in range(diff-1)])
|
||||
done_buf[p].append(True)
|
||||
|
||||
episode_return = env_output['episode_return'] if p == 'landlord' else -env_output['episode_return']
|
||||
episode_return_buf[p].extend([0.0 for _ in range(diff-1)])
|
||||
episode_return_buf[p].append(episode_return)
|
||||
target_buf[p].extend([episode_return for _ in range(diff)])
|
||||
break
|
||||
|
||||
for p in positions:
|
||||
if size[p] > T:
|
||||
index = free_queue[p].get()
|
||||
if index is None:
|
||||
break
|
||||
for t in range(T):
|
||||
buffers[p]['done'][index][t, ...] = done_buf[p][t]
|
||||
buffers[p]['episode_return'][index][t, ...] = episode_return_buf[p][t]
|
||||
buffers[p]['target'][index][t, ...] = target_buf[p][t]
|
||||
buffers[p]['obs_x_no_action'][index][t, ...] = obs_x_no_action_buf[p][t]
|
||||
buffers[p]['obs_action'][index][t, ...] = obs_action_buf[p][t]
|
||||
buffers[p]['obs_z'][index][t, ...] = obs_z_buf[p][t]
|
||||
full_queue[p].put(index)
|
||||
done_buf[p] = done_buf[p][T:]
|
||||
episode_return_buf[p] = episode_return_buf[p][T:]
|
||||
target_buf[p] = target_buf[p][T:]
|
||||
obs_x_no_action_buf[p] = obs_x_no_action_buf[p][T:]
|
||||
obs_action_buf[p] = obs_action_buf[p][T:]
|
||||
obs_z_buf[p] = obs_z_buf[p][T:]
|
||||
size[p] -= T
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.error('Exception in worker process %i', i)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
raise e
|
||||
|
||||
def _cards2tensor(list_cards):
|
||||
"""
|
||||
Convert a list of integers to the tensor
|
||||
representation
|
||||
See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
|
||||
"""
|
||||
if len(list_cards) == 0:
|
||||
return torch.zeros(54, dtype=torch.int8)
|
||||
|
||||
matrix = np.zeros([4, 13], dtype=np.int8)
|
||||
jokers = np.zeros(2, dtype=np.int8)
|
||||
counter = Counter(list_cards)
|
||||
for card, num_times in counter.items():
|
||||
if card < 20:
|
||||
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
||||
elif card == 20:
|
||||
jokers[0] = 1
|
||||
elif card == 30:
|
||||
jokers[1] = 1
|
||||
matrix = np.concatenate((matrix.flatten('F'), jokers))
|
||||
matrix = torch.from_numpy(matrix)
|
||||
return matrix
|
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
from douzero.env.env import get_obs
|
||||
|
||||
def _load_model(position, model_path):
|
||||
from douzero.dmc.models import model_dict
|
||||
model = model_dict[position]()
|
||||
model_state_dict = model.state_dict()
|
||||
if torch.cuda.is_available():
|
||||
pretrained = torch.load(model_path, map_location='cuda:0')
|
||||
else:
|
||||
pretrained = torch.load(model_path, map_location='cpu')
|
||||
pretrained = {k: v for k, v in pretrained.items() if k in model_state_dict}
|
||||
model_state_dict.update(pretrained)
|
||||
model.load_state_dict(model_state_dict)
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
class DeepAgent:
|
||||
|
||||
def __init__(self, position, model_path):
|
||||
self.model = _load_model(position, model_path)
|
||||
|
||||
def act(self, infoset):
|
||||
# 只有一个合法动作时直接返回,这样会得不到胜率信息
|
||||
# if len(infoset.legal_actions) == 1:
|
||||
# return infoset.legal_actions[0], 0
|
||||
|
||||
obs = get_obs(infoset)
|
||||
z_batch = torch.from_numpy(obs['z_batch']).float()
|
||||
x_batch = torch.from_numpy(obs['x_batch']).float()
|
||||
if torch.cuda.is_available():
|
||||
z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
|
||||
y_pred = self.model.forward(z_batch, x_batch, return_value=True)['values']
|
||||
y_pred = y_pred.detach().cpu().numpy()
|
||||
|
||||
best_action_index = np.argmax(y_pred, axis=0)[0]
|
||||
best_action = infoset.legal_actions[best_action_index]
|
||||
best_action_confidence = y_pred[best_action_index]
|
||||
# print(best_action, best_action_confidence, y_pred)
|
||||
return best_action, best_action_confidence
|
|
@ -0,0 +1,9 @@
|
|||
import random
|
||||
|
||||
class RandomAgent():
|
||||
|
||||
def __init__(self):
|
||||
self.name = 'Random'
|
||||
|
||||
def act(self, infoset):
|
||||
return random.choice(infoset.legal_actions)
|
|
@ -0,0 +1,183 @@
|
|||
import random
|
||||
|
||||
from rlcard.games.doudizhu.utils import CARD_TYPE
|
||||
|
||||
EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||
13: 'K', 14: 'A', 17: '2', 20: 'B', 30: 'R'}
|
||||
RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
|
||||
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
|
||||
'K': 13, 'A': 14, '2': 17, 'B': 20, 'R': 30}
|
||||
|
||||
INDEX = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4,
|
||||
'8': 5, '9': 6, 'T': 7, 'J': 8, 'Q': 9,
|
||||
'K': 10, 'A': 11, '2': 12, 'B': 13, 'R': 14}
|
||||
|
||||
class RLCardAgent(object):
|
||||
|
||||
def __init__(self, position):
|
||||
self.name = 'RLCard'
|
||||
self.position = position
|
||||
|
||||
def act(self, infoset):
|
||||
try:
|
||||
# Hand cards
|
||||
hand_cards = infoset.player_hand_cards
|
||||
for i, c in enumerate(hand_cards):
|
||||
hand_cards[i] = EnvCard2RealCard[c]
|
||||
hand_cards = ''.join(hand_cards)
|
||||
|
||||
# Last move
|
||||
last_move = infoset.last_move.copy()
|
||||
for i, c in enumerate(last_move):
|
||||
last_move[i] = EnvCard2RealCard[c]
|
||||
last_move = ''.join(last_move)
|
||||
|
||||
# Last two moves
|
||||
last_two_cards = infoset.last_two_moves
|
||||
for i in range(2):
|
||||
for j, c in enumerate(last_two_cards[i]):
|
||||
last_two_cards[i][j] = EnvCard2RealCard[c]
|
||||
last_two_cards[i] = ''.join(last_two_cards[i])
|
||||
|
||||
# Last pid
|
||||
last_pid = infoset.last_pid
|
||||
|
||||
action = None
|
||||
# the rule of leading round
|
||||
if last_two_cards[0] == '' and last_two_cards[1] == '':
|
||||
chosen_action = None
|
||||
comb = combine_cards(hand_cards)
|
||||
min_card = hand_cards[0]
|
||||
for _, acs in comb.items():
|
||||
for ac in acs:
|
||||
if min_card in ac:
|
||||
chosen_action = ac
|
||||
action = [char for char in chosen_action]
|
||||
for i, c in enumerate(action):
|
||||
action[i] = RealCard2EnvCard[c]
|
||||
#print('lead action:', action)
|
||||
# the rule of following cards
|
||||
else:
|
||||
the_type = CARD_TYPE[0][last_move][0][0]
|
||||
chosen_action = ''
|
||||
rank = 1000
|
||||
for ac in infoset.legal_actions:
|
||||
_ac = ac.copy()
|
||||
for i, c in enumerate(_ac):
|
||||
_ac[i] = EnvCard2RealCard[c]
|
||||
_ac = ''.join(_ac)
|
||||
if _ac != '' and the_type == CARD_TYPE[0][_ac][0][0]:
|
||||
if int(CARD_TYPE[0][_ac][0][1]) < rank:
|
||||
rank = int(CARD_TYPE[0][_ac][0][1])
|
||||
chosen_action = _ac
|
||||
if chosen_action != '':
|
||||
action = [char for char in chosen_action]
|
||||
for i, c in enumerate(action):
|
||||
action[i] = RealCard2EnvCard[c]
|
||||
#print('action:', action)
|
||||
elif last_pid != 'landlord' and self.position != 'landlord':
|
||||
action = []
|
||||
|
||||
if action is None:
|
||||
action = random.choice(infoset.legal_actions)
|
||||
except:
|
||||
action = random.choice(infoset.legal_actions)
|
||||
#import traceback
|
||||
#traceback.print_exc()
|
||||
|
||||
assert action in infoset.legal_actions
|
||||
|
||||
return action
|
||||
|
||||
def card_str2list(hand):
|
||||
hand_list = [0 for _ in range(15)]
|
||||
for card in hand:
|
||||
hand_list[INDEX[card]] += 1
|
||||
return hand_list
|
||||
|
||||
def list2card_str(hand_list):
|
||||
card_str = ''
|
||||
cards = [card for card in INDEX]
|
||||
for index, count in enumerate(hand_list):
|
||||
card_str += cards[index] * count
|
||||
return card_str
|
||||
|
||||
def pick_chain(hand_list, count):
|
||||
chains = []
|
||||
str_card = [card for card in INDEX]
|
||||
hand_list = [str(card) for card in hand_list]
|
||||
hand = ''.join(hand_list[:12])
|
||||
chain_list = hand.split('0')
|
||||
add = 0
|
||||
for index, chain in enumerate(chain_list):
|
||||
if len(chain) > 0:
|
||||
if len(chain) >= 5:
|
||||
start = index + add
|
||||
min_count = int(min(chain)) // count
|
||||
if min_count != 0:
|
||||
str_chain = ''
|
||||
for num in range(len(chain)):
|
||||
str_chain += str_card[start+num]
|
||||
hand_list[start+num] = int(hand_list[start+num]) - int(min(chain))
|
||||
for _ in range(min_count):
|
||||
chains.append(str_chain)
|
||||
add += len(chain)
|
||||
hand_list = [int(card) for card in hand_list]
|
||||
return (chains, hand_list)
|
||||
|
||||
def combine_cards(hand):
|
||||
'''Get optimal combinations of cards in hand
|
||||
'''
|
||||
comb = {'rocket': [], 'bomb': [], 'trio': [], 'trio_chain': [],
|
||||
'solo_chain': [], 'pair_chain': [], 'pair': [], 'solo': []}
|
||||
# 1. pick rocket
|
||||
if hand[-2:] == 'BR':
|
||||
comb['rocket'].append('BR')
|
||||
hand = hand[:-2]
|
||||
# 2. pick bomb
|
||||
hand_cp = hand
|
||||
for index in range(len(hand_cp) - 3):
|
||||
if hand_cp[index] == hand_cp[index+3]:
|
||||
bomb = hand_cp[index: index+4]
|
||||
comb['bomb'].append(bomb)
|
||||
hand = hand.replace(bomb, '')
|
||||
# 3. pick trio and trio_chain
|
||||
hand_cp = hand
|
||||
for index in range(len(hand_cp) - 2):
|
||||
if hand_cp[index] == hand_cp[index+2]:
|
||||
trio = hand_cp[index: index+3]
|
||||
if len(comb['trio']) > 0 and INDEX[trio[-1]] < 12 and (INDEX[trio[-1]]-1) == INDEX[comb['trio'][-1][-1]]:
|
||||
comb['trio'][-1] += trio
|
||||
else:
|
||||
comb['trio'].append(trio)
|
||||
hand = hand.replace(trio, '')
|
||||
only_trio = []
|
||||
only_trio_chain = []
|
||||
for trio in comb['trio']:
|
||||
if len(trio) == 3:
|
||||
only_trio.append(trio)
|
||||
else:
|
||||
only_trio_chain.append(trio)
|
||||
comb['trio'] = only_trio
|
||||
comb['trio_chain'] = only_trio_chain
|
||||
# 4. pick solo chain
|
||||
hand_list = card_str2list(hand)
|
||||
chains, hand_list = pick_chain(hand_list, 1)
|
||||
comb['solo_chain'] = chains
|
||||
# 5. pick par_chain
|
||||
chains, hand_list = pick_chain(hand_list, 2)
|
||||
comb['pair_chain'] = chains
|
||||
hand = list2card_str(hand_list)
|
||||
# 6. pick pair and solo
|
||||
index = 0
|
||||
while index < len(hand) - 1:
|
||||
if hand[index] == hand[index+1]:
|
||||
comb['pair'].append(hand[index] + hand[index+1])
|
||||
index += 2
|
||||
else:
|
||||
comb['solo'].append(hand[index])
|
||||
index += 1
|
||||
if index == (len(hand) - 1):
|
||||
comb['solo'].append(hand[index])
|
||||
return comb
|
|
@ -0,0 +1,73 @@
|
|||
from douzero.env.game import GameEnv
|
||||
from .deep_agent import DeepAgent
|
||||
|
||||
EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
||||
|
||||
RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
|
||||
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
|
||||
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
|
||||
|
||||
AllEnvCard = [3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7,
|
||||
8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 11, 11, 11, 11, 12,
|
||||
12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 14, 17, 17, 17, 17, 20, 30]
|
||||
|
||||
|
||||
def evaluate(landlord, landlord_up, landlord_down):
|
||||
# 输入玩家的牌
|
||||
user_hand_cards_real = input("请输入你的手牌, 例如 333456789TJQKA2XD:")
|
||||
# user_hand_cards_real = "34666777899TJJKA22XD"
|
||||
use_hand_cards_env = [RealCard2EnvCard[c] for c in list(user_hand_cards_real)]
|
||||
# 输入玩家角色
|
||||
user_position_code = int(input("请输入你的角色[0:地主上家, 1:地主, 2:地主下家]:"))
|
||||
# user_position_code = 1
|
||||
user_position = ['landlord_up', 'landlord', 'landlord_down'][user_position_code]
|
||||
# 输入三张底牌
|
||||
three_landlord_cards_real = input("请输入三张底牌, 例如 2XD:")
|
||||
# three_landlord_cards_real = "2XD"
|
||||
three_landlord_cards_env = [RealCard2EnvCard[c] for c in list(three_landlord_cards_real)]
|
||||
|
||||
# 整副牌减去玩家手上的牌,就是其他人的手牌,再分配给另外两个角色(如何分配对AI判断没有影响)
|
||||
other_hand_cards = []
|
||||
for i in set(AllEnvCard):
|
||||
other_hand_cards.extend([i] * (AllEnvCard.count(i) - use_hand_cards_env.count(i)))
|
||||
|
||||
card_play_data_list = [{}]
|
||||
card_play_data_list[0].update({
|
||||
'three_landlord_cards': three_landlord_cards_env,
|
||||
['landlord_up', 'landlord', 'landlord_down'][(user_position_code + 0) % 3]: use_hand_cards_env,
|
||||
['landlord_up', 'landlord', 'landlord_down'][(user_position_code + 1) % 3]: other_hand_cards[0:17] if (user_position_code + 1) % 3 != 1 else other_hand_cards[17:],
|
||||
['landlord_up', 'landlord', 'landlord_down'][(user_position_code + 2) % 3]: other_hand_cards[0:17] if (user_position_code + 1) % 3 == 1 else other_hand_cards[17:]
|
||||
})
|
||||
# 生成手牌结束,校验手牌数量
|
||||
if len(card_play_data_list[0]["three_landlord_cards"]) != 3:
|
||||
print("底牌必须是3张\n")
|
||||
return
|
||||
if len(card_play_data_list[0]["landlord_up"]) != 17 or \
|
||||
len(card_play_data_list[0]["landlord_down"]) != 17 or \
|
||||
len(card_play_data_list[0]["landlord"]) != 20:
|
||||
print("初始手牌数目有误\n")
|
||||
return
|
||||
|
||||
# print(card_play_data_list)
|
||||
card_play_model_path_dict = {
|
||||
'landlord': landlord,
|
||||
'landlord_up': landlord_up,
|
||||
'landlord_down': landlord_down}
|
||||
|
||||
print("创建代表玩家的AI...")
|
||||
players = {}
|
||||
players[user_position] = DeepAgent(user_position, card_play_model_path_dict[user_position])
|
||||
|
||||
env = GameEnv(players)
|
||||
for idx, card_play_data in enumerate(card_play_data_list):
|
||||
env.card_play_init(card_play_data)
|
||||
print("开始出牌\n")
|
||||
while not env.game_over:
|
||||
env.step()
|
||||
print("{}胜,本局结束!\n".format("农民" if env.winner == "farmer" else "地主"))
|
||||
env.reset()
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,570 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Created by: Raf
|
||||
# Modify by: Vincentzyx
|
||||
|
||||
import GameHelper as gh
|
||||
from GameHelper import GameHelper
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import pyautogui
|
||||
import win32gui
|
||||
from PIL import Image
|
||||
import multiprocessing as mp
|
||||
|
||||
from PyQt5 import QtGui, QtWidgets, QtCore
|
||||
from PyQt5.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem, QGraphicsPixmapItem, QInputDialog, QMessageBox
|
||||
from PyQt5.QtGui import QPixmap, QIcon
|
||||
from PyQt5.QtCore import QTime, QEventLoop
|
||||
from MainWindow import Ui_Form
|
||||
|
||||
from douzero.env.game import GameEnv
|
||||
from douzero.evaluation.deep_agent import DeepAgent
|
||||
|
||||
import BidModel
|
||||
import LandlordModel
|
||||
import FarmerModel
|
||||
|
||||
EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
||||
|
||||
RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
|
||||
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
|
||||
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
|
||||
|
||||
AllEnvCard = [3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7,
|
||||
8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 11, 11, 11, 11, 12,
|
||||
12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 14, 17, 17, 17, 17, 20, 30]
|
||||
|
||||
AllCards = ['rD', 'bX', 'b2', 'r2', 'bA', 'rA', 'bK', 'rK', 'bQ', 'rQ', 'bJ', 'rJ', 'bT', 'rT',
|
||||
'b9', 'r9', 'b8', 'r8', 'b7', 'r7', 'b6', 'r6', 'b5', 'r5', 'b4', 'r4', 'b3', 'r3']
|
||||
|
||||
helper = GameHelper()
|
||||
helper.ScreenZoomRate = 1.25 # 请修改屏幕缩放比
|
||||
|
||||
class MyPyQT_Form(QtWidgets.QWidget, Ui_Form):
|
||||
def __init__(self):
|
||||
super(MyPyQT_Form, self).__init__()
|
||||
self.setupUi(self)
|
||||
self.setWindowFlags(QtCore.Qt.WindowMinimizeButtonHint | # 使能最小化按钮
|
||||
QtCore.Qt.WindowCloseButtonHint) # 窗体总在最前端 QtCore.Qt.WindowStaysOnTopHint
|
||||
self.setFixedSize(self.width(), self.height()) # 固定窗体大小
|
||||
# self.setWindowIcon(QIcon('pics/favicon.ico'))
|
||||
window_pale = QtGui.QPalette()
|
||||
# window_pale.setBrush(self.backgroundRole(), QtGui.QBrush(QtGui.QPixmap("pics/bg.png")))
|
||||
self.setPalette(window_pale)
|
||||
|
||||
self.Players = [self.RPlayer, self.Player, self.LPlayer]
|
||||
self.counter = QTime()
|
||||
|
||||
# 参数
|
||||
self.MyConfidence = 0.95 # 我的牌的置信度
|
||||
self.OtherConfidence = 0.9 # 别人的牌的置信度
|
||||
self.WhiteConfidence = 0.95 # 检测白块的置信度
|
||||
self.LandlordFlagConfidence = 0.9 # # 检测地主标志的置信度
|
||||
self.ThreeLandlordCardsConfidence = 0.9 # 检测地主底牌的置信度
|
||||
self.PassConfidence = 0.85
|
||||
self.WaitTime = 1 # 等待状态稳定延时
|
||||
self.MyFilter = 40 # 我的牌检测结果过滤参数
|
||||
self.OtherFilter = 25 # 别人的牌检测结果过滤参数
|
||||
self.SleepTime = 0.1 # 循环中睡眠时间
|
||||
self.RunGame = False
|
||||
self.AutoPlay = False
|
||||
# 坐标
|
||||
self.MyHandCardsPos = (250, 764, 1141, 70) # 我的截图区域
|
||||
self.LPlayedCardsPos = (463, 355, 380, 250) # 左边截图区域
|
||||
self.RPlayedCardsPos = (946, 355, 380, 250) # 右边截图区域
|
||||
self.LandlordFlagPos = [(1281, 276, 110, 140), (267, 695, 110, 140), (424, 237, 110, 140)] # 地主标志截图区域(右-我-左)
|
||||
self.ThreeLandlordCardsPos = (763, 37, 287, 136) # 地主底牌截图区域,resize成349x168
|
||||
self.PassBtnPos = (686, 659, 419, 100)
|
||||
self.GeneralBtnPos = (616, 631, 576, 117)
|
||||
# 信号量
|
||||
self.shouldExit = 0 # 通知上一轮记牌结束
|
||||
self.canRecord = threading.Lock() # 开始记牌
|
||||
self.card_play_model_path_dict = {
|
||||
'landlord': "baselines/douzero_ADP/landlord.ckpt",
|
||||
'landlord_up': "baselines/douzero_ADP/landlord_up.ckpt",
|
||||
'landlord_down': "baselines/douzero_ADP/landlord_down.ckpt"
|
||||
}
|
||||
# cards = self.find_three_landlord_cards(self.ThreeLandlordCardsPos)
|
||||
# print(cards)
|
||||
# exit()
|
||||
|
||||
def init_display(self):
|
||||
self.WinRate.setText("评分")
|
||||
self.InitCard.setText("开始")
|
||||
self.UserHandCards.setText("手牌")
|
||||
self.LPlayedCard.setText("上家出牌区域")
|
||||
self.RPlayedCard.setText("下家出牌区域")
|
||||
self.PredictedCard.setText("AI出牌区域")
|
||||
self.ThreeLandlordCards.setText("地主牌")
|
||||
self.SwitchMode.setText("自动" if self.AutoPlay else "单局")
|
||||
for player in self.Players:
|
||||
player.setStyleSheet('background-color: rgba(255, 0, 0, 0);')
|
||||
|
||||
def switch_mode(self):
|
||||
self.AutoPlay = not self.AutoPlay
|
||||
self.SwitchMode.setText("自动" if self.AutoPlay else "单局")
|
||||
|
||||
def init_cards(self):
|
||||
self.RunGame = True
|
||||
GameHelper.Interrupt = False
|
||||
self.init_display()
|
||||
# 玩家手牌
|
||||
self.user_hand_cards_real = ""
|
||||
self.user_hand_cards_env = []
|
||||
# 其他玩家出牌
|
||||
self.other_played_cards_real = ""
|
||||
self.other_played_cards_env = []
|
||||
# 其他玩家手牌(整副牌减去玩家手牌,后续再减掉历史出牌)
|
||||
self.other_hand_cards = []
|
||||
# 三张底牌
|
||||
self.three_landlord_cards_real = ""
|
||||
self.three_landlord_cards_env = []
|
||||
# 玩家角色代码:0-地主上家, 1-地主, 2-地主下家
|
||||
self.user_position_code = None
|
||||
self.user_position = ""
|
||||
# 开局时三个玩家的手牌
|
||||
self.card_play_data_list = {}
|
||||
# 出牌顺序:0-玩家出牌, 1-玩家下家出牌, 2-玩家上家出牌
|
||||
self.play_order = 0
|
||||
|
||||
self.env = None
|
||||
|
||||
# 识别玩家手牌
|
||||
self.user_hand_cards_real = self.find_my_cards(self.MyHandCardsPos)
|
||||
self.UserHandCards.setText(self.user_hand_cards_real)
|
||||
self.user_hand_cards_env = [RealCard2EnvCard[c] for c in list(self.user_hand_cards_real)]
|
||||
# 识别三张底牌
|
||||
self.three_landlord_cards_real = self.find_three_landlord_cards(self.ThreeLandlordCardsPos)
|
||||
self.ThreeLandlordCards.setText("底牌:" + self.three_landlord_cards_real)
|
||||
self.three_landlord_cards_env = [RealCard2EnvCard[c] for c in list(self.three_landlord_cards_real)]
|
||||
for testCount in range(1, 5):
|
||||
if len(self.three_landlord_cards_env) > 3:
|
||||
self.ThreeLandlordCardsConfidence += 0.05
|
||||
elif len(self.three_landlord_cards_env) < 3:
|
||||
self.ThreeLandlordCardsConfidence -= 0.05
|
||||
else:
|
||||
break
|
||||
self.three_landlord_cards_real = self.find_three_landlord_cards(self.ThreeLandlordCardsPos)
|
||||
self.ThreeLandlordCards.setText("底牌:" + self.three_landlord_cards_real)
|
||||
self.three_landlord_cards_env = [RealCard2EnvCard[c] for c in list(self.three_landlord_cards_real)]
|
||||
# 识别玩家的角色
|
||||
self.user_position_code = self.find_landlord(self.LandlordFlagPos)
|
||||
if self.user_position_code is None:
|
||||
items = ("地主上家", "地主", "地主下家")
|
||||
item, okPressed = QInputDialog.getItem(self, "选择角色", "未识别到地主,请手动选择角色:", items, 0, False)
|
||||
if okPressed and item:
|
||||
self.user_position_code = items.index(item)
|
||||
else:
|
||||
return
|
||||
self.user_position = ['landlord_up', 'landlord', 'landlord_down'][self.user_position_code]
|
||||
for player in self.Players:
|
||||
player.setStyleSheet('background-color: rgba(255, 0, 0, 0);')
|
||||
self.Players[self.user_position_code].setStyleSheet('background-color: rgba(255, 0, 0, 0.1);')
|
||||
|
||||
# 整副牌减去玩家手上的牌,就是其他人的手牌,再分配给另外两个角色(如何分配对AI判断没有影响)
|
||||
for i in set(AllEnvCard):
|
||||
self.other_hand_cards.extend([i] * (AllEnvCard.count(i) - self.user_hand_cards_env.count(i)))
|
||||
self.card_play_data_list.update({
|
||||
'three_landlord_cards': self.three_landlord_cards_env,
|
||||
['landlord_up', 'landlord', 'landlord_down'][(self.user_position_code + 0) % 3]:
|
||||
self.user_hand_cards_env,
|
||||
['landlord_up', 'landlord', 'landlord_down'][(self.user_position_code + 1) % 3]:
|
||||
self.other_hand_cards[0:17] if (self.user_position_code + 1) % 3 != 1 else self.other_hand_cards[17:],
|
||||
['landlord_up', 'landlord', 'landlord_down'][(self.user_position_code + 2) % 3]:
|
||||
self.other_hand_cards[0:17] if (self.user_position_code + 1) % 3 == 1 else self.other_hand_cards[17:]
|
||||
})
|
||||
print("开始对局")
|
||||
print("手牌:",self.user_hand_cards_real)
|
||||
print("地主牌:",self.three_landlord_cards_real)
|
||||
# 生成手牌结束,校验手牌数量
|
||||
if len(self.card_play_data_list["three_landlord_cards"]) != 3:
|
||||
QMessageBox.critical(self, "底牌识别出错", "底牌必须是3张!", QMessageBox.Yes, QMessageBox.Yes)
|
||||
self.init_display()
|
||||
return
|
||||
if len(self.card_play_data_list["landlord_up"]) != 17 or \
|
||||
len(self.card_play_data_list["landlord_down"]) != 17 or \
|
||||
len(self.card_play_data_list["landlord"]) != 20:
|
||||
QMessageBox.critical(self, "手牌识别出错", "初始手牌数目有误", QMessageBox.Yes, QMessageBox.Yes)
|
||||
self.init_display()
|
||||
return
|
||||
# 得到出牌顺序
|
||||
self.play_order = 0 if self.user_position == "landlord" else 1 if self.user_position == "landlord_up" else 2
|
||||
|
||||
# 创建一个代表玩家的AI
|
||||
ai_players = [0, 0]
|
||||
ai_players[0] = self.user_position
|
||||
ai_players[1] = DeepAgent(self.user_position, self.card_play_model_path_dict[self.user_position])
|
||||
|
||||
self.env = GameEnv(ai_players)
|
||||
try:
|
||||
self.start()
|
||||
except:
|
||||
self.stop()
|
||||
|
||||
def sleep(self, ms):
|
||||
self.counter.restart()
|
||||
while self.counter.elapsed() < ms:
|
||||
QtWidgets.QApplication.processEvents(QEventLoop.AllEvents, 50)
|
||||
|
||||
def start(self):
|
||||
self.env.card_play_init(self.card_play_data_list)
|
||||
print("开始出牌\n")
|
||||
while not self.env.game_over:
|
||||
# 玩家出牌时就通过智能体获取action,否则通过识别获取其他玩家出牌
|
||||
if self.play_order == 0:
|
||||
self.PredictedCard.setText("...")
|
||||
action_message = self.env.step(self.user_position)
|
||||
# 更新界面
|
||||
self.UserHandCards.setText("手牌:" + str(''.join(
|
||||
[EnvCard2RealCard[c] for c in self.env.info_sets[self.user_position].player_hand_cards]))[::-1])
|
||||
|
||||
self.PredictedCard.setText(action_message["action"] if action_message["action"] else "不出")
|
||||
self.WinRate.setText("评分:" + action_message["win_rate"])
|
||||
print("\n手牌:", str(''.join(
|
||||
[EnvCard2RealCard[c] for c in self.env.info_sets[self.user_position].player_hand_cards])))
|
||||
print("出牌:", action_message["action"] if action_message["action"] else "不出", ", 胜率:",
|
||||
action_message["win_rate"])
|
||||
if action_message["action"] == "":
|
||||
helper.ClickOnImage("pass_btn", region=self.PassBtnPos)
|
||||
else:
|
||||
helper.SelectCards(action_message["action"])
|
||||
tryCount = 20
|
||||
result = helper.LocateOnScreen("play_card", region=self.PassBtnPos, confidence=0.85)
|
||||
while result is None and tryCount > 0:
|
||||
if not self.RunGame:
|
||||
break
|
||||
print("等待出牌按钮")
|
||||
self.detect_start_btn()
|
||||
tryCount -= 1
|
||||
result = helper.LocateOnScreen("play_card", region=self.PassBtnPos, confidence=0.85)
|
||||
self.sleep(100)
|
||||
helper.ClickOnImage("play_card", region=self.PassBtnPos, confidence=0.85)
|
||||
self.sleep(2200)
|
||||
self.detect_start_btn()
|
||||
self.play_order = 1
|
||||
elif self.play_order == 1:
|
||||
self.RPlayedCard.setText("...")
|
||||
pass_flag = helper.LocateOnScreen('pass',
|
||||
region=self.RPlayedCardsPos,
|
||||
confidence=self.PassConfidence)
|
||||
self.detect_start_btn()
|
||||
while self.RunGame and self.have_white(self.RPlayedCardsPos) == 0 and pass_flag is None:
|
||||
print("等待下家出牌")
|
||||
self.sleep(100)
|
||||
pass_flag = helper.LocateOnScreen('pass', region=self.RPlayedCardsPos,
|
||||
confidence=self.PassConfidence)
|
||||
self.detect_start_btn()
|
||||
self.sleep(200)
|
||||
# 未找到"不出"
|
||||
if pass_flag is None:
|
||||
# 识别下家出牌
|
||||
self.RPlayedCard.setText("等待动画")
|
||||
self.sleep(1200)
|
||||
self.RPlayedCard.setText("识别中")
|
||||
self.other_played_cards_real = self.find_other_cards(self.RPlayedCardsPos)
|
||||
print("下家出牌", self.other_played_cards_real)
|
||||
self.sleep(500)
|
||||
# 找到"不出"
|
||||
else:
|
||||
self.other_played_cards_real = ""
|
||||
print("\n下家出牌:", self.other_played_cards_real)
|
||||
self.other_played_cards_env = [RealCard2EnvCard[c] for c in list(self.other_played_cards_real)]
|
||||
self.env.step(self.user_position, self.other_played_cards_env)
|
||||
# 更新界面
|
||||
self.RPlayedCard.setText(self.other_played_cards_real if self.other_played_cards_real else "不出")
|
||||
self.play_order = 2
|
||||
self.sleep(500)
|
||||
elif self.play_order == 2:
|
||||
self.LPlayedCard.setText("...")
|
||||
self.detect_start_btn()
|
||||
pass_flag = helper.LocateOnScreen('pass', region=self.LPlayedCardsPos,
|
||||
confidence=self.PassConfidence)
|
||||
while self.RunGame and self.have_white(self.LPlayedCardsPos) == 0 and pass_flag is None:
|
||||
print("等待上家出牌")
|
||||
self.detect_start_btn()
|
||||
self.sleep(100)
|
||||
pass_flag = helper.LocateOnScreen('pass', region=self.LPlayedCardsPos,
|
||||
confidence=self.PassConfidence)
|
||||
self.sleep(200)
|
||||
# 不出
|
||||
# 未找到"不出"
|
||||
if pass_flag is None:
|
||||
# 识别上家出牌
|
||||
self.LPlayedCard.setText("等待动画")
|
||||
self.sleep(1200)
|
||||
self.LPlayedCard.setText("识别中")
|
||||
self.other_played_cards_real = self.find_other_cards(self.LPlayedCardsPos)
|
||||
# 找到"不出"
|
||||
else:
|
||||
self.other_played_cards_real = ""
|
||||
print("\n上家出牌:", self.other_played_cards_real)
|
||||
self.other_played_cards_env = [RealCard2EnvCard[c] for c in list(self.other_played_cards_real)]
|
||||
self.env.step(self.user_position, self.other_played_cards_env)
|
||||
self.play_order = 0
|
||||
# 更新界面
|
||||
self.LPlayedCard.setText(self.other_played_cards_real if self.other_played_cards_real else "不出")
|
||||
self.sleep(500)
|
||||
else:
|
||||
pass
|
||||
self.sleep(100)
|
||||
|
||||
print("{}胜,本局结束!\n".format("农民" if self.env.winner == "farmer" else "地主"))
|
||||
# QMessageBox.information(self, "本局结束", "{}胜!".format("农民" if self.env.winner == "farmer" else "地主"),
|
||||
# QMessageBox.Yes, QMessageBox.Yes)
|
||||
self.detect_start_btn()
|
||||
|
||||
def find_landlord(self, landlord_flag_pos):
|
||||
for pos in landlord_flag_pos:
|
||||
result = helper.LocateOnScreen("landlord_words", region=pos,
|
||||
confidence=self.LandlordFlagConfidence)
|
||||
if result is not None:
|
||||
return landlord_flag_pos.index(pos)
|
||||
return None
|
||||
|
||||
def detect_start_btn(self):
|
||||
result = helper.LocateOnScreen("change_player_btn", region=(667, 741, 934, 404))
|
||||
if result is not None:
|
||||
self.RunGame = False
|
||||
self.stop()
|
||||
result = helper.LocateOnScreen("yes_btn", region=(680, 661, 435, 225))
|
||||
if result is not None:
|
||||
helper.ClickOnImage("yes_btn", region=(680, 661, 435, 225))
|
||||
self.sleep(1000)
|
||||
result = helper.LocateOnScreen("get_award_btn", region=(680, 661, 435, 225))
|
||||
if result is not None:
|
||||
helper.ClickOnImage("get_award_btn", region=(680, 661, 435, 225))
|
||||
self.sleep(1000)
|
||||
result = helper.LocateOnScreen("yes_btn_sm", region=(669, 583, 468, 100))
|
||||
if result is not None:
|
||||
helper.ClickOnImage("yes_btn_sm", region=(669, 583, 468, 100))
|
||||
self.sleep(200)
|
||||
|
||||
|
||||
def find_three_landlord_cards(self, pos):
|
||||
img, _ = helper.Screenshot()
|
||||
img = img.crop((pos[0], pos[1], pos[0] + pos[2], pos[1] + pos[3]))
|
||||
img = img.resize((349, 168))
|
||||
three_landlord_cards_real = ""
|
||||
for card in AllCards:
|
||||
result = pyautogui.locateAll(needleImage=helper.Pics['o' + card], haystackImage=img,
|
||||
confidence=self.ThreeLandlordCardsConfidence)
|
||||
three_landlord_cards_real += card[1] * self.cards_filter(list(result), self.OtherFilter)
|
||||
if len(three_landlord_cards_real) > 3:
|
||||
three_landlord_cards_real = ""
|
||||
for card in AllCards:
|
||||
result = pyautogui.locateAll(needleImage=helper.Pics['o' + card], haystackImage=img,
|
||||
confidence=self.ThreeLandlordCardsConfidence + 0.05)
|
||||
three_landlord_cards_real += card[1] * self.cards_filter(list(result), self.OtherFilter)
|
||||
if len(three_landlord_cards_real) < 3:
|
||||
three_landlord_cards_real = ""
|
||||
for card in AllCards:
|
||||
result = pyautogui.locateAll(needleImage=helper.Pics['o' + card], haystackImage=img,
|
||||
confidence=self.ThreeLandlordCardsConfidence + 0.1)
|
||||
three_landlord_cards_real += card[1] * self.cards_filter(list(result), self.OtherFilter)
|
||||
return three_landlord_cards_real
|
||||
|
||||
def find_my_cards(self, pos):
|
||||
user_hand_cards_real = ""
|
||||
img, _ = helper.Screenshot()
|
||||
cards, _ = helper.GetCards(img)
|
||||
for c in cards:
|
||||
user_hand_cards_real += c[0]
|
||||
# for card in AllCards:
|
||||
# result = pyautogui.locateAll(needleImage=helper.Pics['m'+card], haystackImage=img, confidence=self.MyConfidence)
|
||||
# user_hand_cards_real += card[1] * self.cards_filter(list(result), self.MyFilter)
|
||||
return user_hand_cards_real
|
||||
|
||||
def find_other_cards(self, pos):
|
||||
other_played_cards_real = ""
|
||||
self.sleep(500)
|
||||
img, _ = helper.Screenshot(region=pos)
|
||||
for card in AllCards:
|
||||
result = pyautogui.locateAll(needleImage=helper.Pics['o' + card], haystackImage=img,
|
||||
confidence=self.OtherConfidence)
|
||||
other_played_cards_real += card[1] * self.cards_filter(list(result), self.OtherFilter)
|
||||
return other_played_cards_real
|
||||
|
||||
def cards_filter(self, location, distance): # 牌检测结果滤波
|
||||
if len(location) == 0:
|
||||
return 0
|
||||
locList = [location[0][0]]
|
||||
count = 1
|
||||
for e in location:
|
||||
flag = 1 # “是新的”标志
|
||||
for have in locList:
|
||||
if abs(e[0] - have) <= distance:
|
||||
flag = 0
|
||||
break
|
||||
if flag:
|
||||
count += 1
|
||||
locList.append(e[0])
|
||||
return count
|
||||
|
||||
def have_white(self, pos): # 是否有白块
|
||||
img, _ = helper.Screenshot()
|
||||
result = pyautogui.locate(needleImage=helper.Pics["white"], haystackImage=img,
|
||||
region=pos, confidence=self.WhiteConfidence)
|
||||
if result is None:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
def stop(self):
|
||||
try:
|
||||
self.RunGame = False
|
||||
self.env.game_over = True
|
||||
self.env.reset()
|
||||
self.init_display()
|
||||
self.PreWinrate.setText("局前预估胜率:")
|
||||
self.BidWinrate.setText("叫牌预估胜率:")
|
||||
except AttributeError as e:
|
||||
pass
|
||||
if self.AutoPlay:
|
||||
play_btn = helper.LocateOnScreen("change_player_btn", region=(667, 741, 934, 404))
|
||||
while play_btn is None and self.AutoPlay:
|
||||
play_btn = helper.LocateOnScreen("change_player_btn", region=(667, 741, 934, 404))
|
||||
self.sleep(100)
|
||||
if play_btn is not None:
|
||||
helper.LeftClick((play_btn[0], play_btn[1]))
|
||||
self.beforeStart()
|
||||
|
||||
def beforeStart(self):
|
||||
GameHelper.Interrupt = True
|
||||
thresholds = [
|
||||
[75, 60],
|
||||
[85, 70]
|
||||
]
|
||||
while True:
|
||||
outterBreak = False
|
||||
jiaodizhu_btn = helper.LocateOnScreen("jiaodizhu_btn", region=(765, 663, 116, 50))
|
||||
qiangdizhu_btn = helper.LocateOnScreen("qiangdizhu_btn", region=(783, 663, 116, 50))
|
||||
jiabei_btn = helper.LocateOnScreen("jiabei_btn", region=self.GeneralBtnPos)
|
||||
self.detect_start_btn()
|
||||
while jiaodizhu_btn is None and qiangdizhu_btn is None and jiabei_btn is None:
|
||||
self.detect_start_btn()
|
||||
print("等待加倍或叫地主")
|
||||
self.sleep(100)
|
||||
jiaodizhu_btn = helper.LocateOnScreen("jiaodizhu_btn", region=(765, 663, 116, 50))
|
||||
qiangdizhu_btn = helper.LocateOnScreen("qiangdizhu_btn", region=(783, 663, 116, 50))
|
||||
jiabei_btn = helper.LocateOnScreen("jiabei_btn", region=self.GeneralBtnPos)
|
||||
if jiabei_btn is None:
|
||||
img, _ = helper.Screenshot()
|
||||
cards, _ = helper.GetCards(img)
|
||||
cards_str = "".join([card[0] for card in cards])
|
||||
win_rate = BidModel.predict(cards_str)
|
||||
print("预计叫地主胜率:", win_rate)
|
||||
self.BidWinrate.setText("叫牌预估胜率:" + str(round(win_rate, 2)) + "%")
|
||||
is_stolen = 0
|
||||
if jiaodizhu_btn is not None:
|
||||
if win_rate > 55:
|
||||
helper.ClickOnImage("jiaodizhu_btn", region=(765, 663, 116, 50), confidence=0.9)
|
||||
else:
|
||||
helper.ClickOnImage("bujiao_btn", region=self.GeneralBtnPos)
|
||||
elif qiangdizhu_btn is not None:
|
||||
is_stolen = 1
|
||||
if win_rate > 60:
|
||||
helper.ClickOnImage("qiangdizhu_btn", region=(783, 663, 116, 50), confidence=0.9)
|
||||
else:
|
||||
helper.ClickOnImage("buqiang_btn", region=self.GeneralBtnPos)
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
llcards = self.find_three_landlord_cards(self.ThreeLandlordCardsPos)
|
||||
print("地主牌:", llcards)
|
||||
img, _ = helper.Screenshot()
|
||||
cards, _ = helper.GetCards(img)
|
||||
cards_str = "".join([card[0] for card in cards])
|
||||
if len(cards_str) == 20:
|
||||
win_rate = LandlordModel.predict(cards_str)
|
||||
self.PreWinrate.setText("局前预估胜率:" + str(round(win_rate, 2)) + "%")
|
||||
print("预估地主胜率:", win_rate)
|
||||
else:
|
||||
user_position_code = self.find_landlord(self.LandlordFlagPos)
|
||||
user_position = "up"
|
||||
while user_position_code is None:
|
||||
user_position_code = self.find_landlord(self.LandlordFlagPos)
|
||||
self.sleep(50)
|
||||
user_position = ['up', 'landlord', 'down'][user_position_code]
|
||||
win_rate = FarmerModel.predict(cards_str, llcards, user_position) - 5
|
||||
print("预估农民胜率:", win_rate)
|
||||
self.PreWinrate.setText("局前预估胜率:" + str(round(win_rate, 2)) + "%")
|
||||
if win_rate > thresholds[is_stolen][0]:
|
||||
chaojijiabei_btn = helper.LocateOnScreen("chaojijiabei_btn", region=self.GeneralBtnPos)
|
||||
if chaojijiabei_btn is not None:
|
||||
helper.ClickOnImage("chaojijiabei_btn", region=self.GeneralBtnPos)
|
||||
else:
|
||||
helper.ClickOnImage("jiabei_btn", region=self.GeneralBtnPos)
|
||||
elif win_rate > thresholds[is_stolen][1]:
|
||||
helper.ClickOnImage("jiabei_btn", region=self.GeneralBtnPos)
|
||||
else:
|
||||
helper.ClickOnImage("bujiabei_btn", region=self.GeneralBtnPos)
|
||||
outterBreak = True
|
||||
break
|
||||
if outterBreak:
|
||||
break
|
||||
|
||||
llcards = self.find_three_landlord_cards(self.ThreeLandlordCardsPos)
|
||||
while len(llcards) != 3:
|
||||
print("等待地主牌", llcards)
|
||||
self.sleep(100)
|
||||
llcards = self.find_three_landlord_cards(self.ThreeLandlordCardsPos)
|
||||
|
||||
self.sleep(4000)
|
||||
self.init_cards()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = QtWidgets.QApplication(sys.argv)
|
||||
app.setStyleSheet("""
|
||||
QPushButton{
|
||||
text-align : center;
|
||||
background-color : white;
|
||||
font: bold;
|
||||
border-color: gray;
|
||||
border-width: 2px;
|
||||
border-radius: 10px;
|
||||
padding: 6px;
|
||||
height : 14px;
|
||||
border-style: outset;
|
||||
font : 14px;
|
||||
}
|
||||
QPushButton:hover{
|
||||
background-color : light gray;
|
||||
}
|
||||
QPushButton:pressed{
|
||||
text-align : center;
|
||||
background-color : gray;
|
||||
font: bold;
|
||||
border-color: gray;
|
||||
border-width: 2px;
|
||||
border-radius: 10px;
|
||||
padding: 6px;
|
||||
height : 14px;
|
||||
border-style: outset;
|
||||
font : 14px;
|
||||
padding-left:9px;
|
||||
padding-top:9px;
|
||||
}
|
||||
QComboBox{
|
||||
background:transparent;
|
||||
border: 1px solid rgba(200, 200, 200, 100);
|
||||
font-weight: bold;
|
||||
}
|
||||
QComboBox:drop-down{
|
||||
border: 0px;
|
||||
}
|
||||
QComboBox QAbstractItemView:item{
|
||||
height: 30px;
|
||||
}
|
||||
QLabel{
|
||||
background:transparent;
|
||||
font-weight: bold;
|
||||
}
|
||||
""")
|
||||
my_pyqt_form = MyPyQT_Form()
|
||||
my_pyqt_form.show()
|
||||
sys.exit(app.exec_())
|
After Width: | Height: | Size: 1.6 MiB |
After Width: | Height: | Size: 3.9 KiB |
After Width: | Height: | Size: 2.9 KiB |
After Width: | Height: | Size: 3.2 KiB |
After Width: | Height: | Size: 1.9 KiB |
After Width: | Height: | Size: 1002 B |
After Width: | Height: | Size: 206 B |
After Width: | Height: | Size: 309 B |
After Width: | Height: | Size: 242 B |
After Width: | Height: | Size: 4.3 KiB |
After Width: | Height: | Size: 5.0 KiB |
After Width: | Height: | Size: 7.1 KiB |
After Width: | Height: | Size: 3.4 KiB |
After Width: | Height: | Size: 3.0 KiB |
After Width: | Height: | Size: 3.3 KiB |
After Width: | Height: | Size: 15 KiB |
After Width: | Height: | Size: 4.1 KiB |
After Width: | Height: | Size: 2.2 KiB |
After Width: | Height: | Size: 2.3 KiB |
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 2.2 KiB |
After Width: | Height: | Size: 2.4 KiB |
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 2.5 KiB |
After Width: | Height: | Size: 2.5 KiB |
After Width: | Height: | Size: 2.3 KiB |
After Width: | Height: | Size: 1.8 KiB |
After Width: | Height: | Size: 2.2 KiB |
After Width: | Height: | Size: 2.2 KiB |
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 1.8 KiB |
After Width: | Height: | Size: 2.4 KiB |
After Width: | Height: | Size: 2.5 KiB |
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 2.4 KiB |
After Width: | Height: | Size: 2.6 KiB |
After Width: | Height: | Size: 2.1 KiB |
After Width: | Height: | Size: 2.6 KiB |
After Width: | Height: | Size: 2.6 KiB |
After Width: | Height: | Size: 2.4 KiB |
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 1.8 KiB |
After Width: | Height: | Size: 2.4 KiB |
After Width: | Height: | Size: 2.3 KiB |
After Width: | Height: | Size: 2.1 KiB |
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 1.8 KiB |
After Width: | Height: | Size: 2.0 KiB |
After Width: | Height: | Size: 2.1 KiB |
After Width: | Height: | Size: 1.7 KiB |
After Width: | Height: | Size: 2.1 KiB |
After Width: | Height: | Size: 2.1 KiB |
After Width: | Height: | Size: 2.0 KiB |