调整评估模型加载逻辑
This commit is contained in:
parent
29866e5dbf
commit
9bdcc85cc7
|
@ -23,16 +23,11 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx
|
||||||
else:
|
else:
|
||||||
model = model_dict[position]()
|
model = model_dict[position]()
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
if torch.cuda.is_available():
|
|
||||||
pretrained = torch.load(model_path, map_location='cuda:0')
|
|
||||||
else:
|
|
||||||
pretrained = torch.load(model_path, map_location='cpu')
|
pretrained = torch.load(model_path, map_location='cpu')
|
||||||
pretrained = {k: v for k, v in pretrained.items() if k in model_state_dict}
|
pretrained = {k: v for k, v in pretrained.items() if k in model_state_dict}
|
||||||
model_state_dict.update(pretrained)
|
model_state_dict.update(pretrained)
|
||||||
model.load_state_dict(model_state_dict)
|
model.load_state_dict(model_state_dict)
|
||||||
# torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt"))
|
# torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt"))
|
||||||
if torch.cuda.is_available():
|
|
||||||
model.cuda()
|
|
||||||
model.eval()
|
model.eval()
|
||||||
model_path = model_path + '.onnx'
|
model_path = model_path + '.onnx'
|
||||||
if use_onnx and not os.path.exists(model_path):
|
if use_onnx and not os.path.exists(model_path):
|
||||||
|
@ -58,14 +53,19 @@ class DeepAgent:
|
||||||
self.lite_model = True if "lite" in model_path else False
|
self.lite_model = True if "lite" in model_path else False
|
||||||
self.model_type = "general" if "resnet" in model_path else "old"
|
self.model_type = "general" if "resnet" in model_path else "old"
|
||||||
self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model, use_onnx=use_onnx)
|
self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model, use_onnx=use_onnx)
|
||||||
if use_onnx:
|
self.onnx_model_path = os.path.abspath(model_path + '.onnx')
|
||||||
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
|
self.use_onnx = use_onnx
|
||||||
else:
|
|
||||||
self.onnx_model = None
|
self.onnx_model = None
|
||||||
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||||
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
||||||
def act(self, infoset, with_confidence = False):
|
def act(self, infoset, with_confidence = False):
|
||||||
|
if self.use_onnx and self.onnx_model is None:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CPUExecutionProvider'])
|
||||||
|
else:
|
||||||
|
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CUDAExecutionProvider'])
|
||||||
|
|
||||||
if not with_confidence and len(infoset.legal_actions) == 1:
|
if not with_confidence and len(infoset.legal_actions) == 1:
|
||||||
return infoset.legal_actions[0]
|
return infoset.legal_actions[0]
|
||||||
|
|
||||||
|
|
|
@ -22,8 +22,7 @@ def load_card_play_models(card_play_model_path_dict):
|
||||||
players[position] = DeepAgent(position, card_play_model_path_dict[position], use_onnx=True)
|
players[position] = DeepAgent(position, card_play_model_path_dict[position], use_onnx=True)
|
||||||
return players
|
return players
|
||||||
|
|
||||||
def mp_simulate(card_play_data_list, card_play_model_path_dict, q, output, title):
|
def mp_simulate(card_play_data_list, players, q, output, title):
|
||||||
players = load_card_play_models(card_play_model_path_dict)
|
|
||||||
EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||||
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
||||||
|
@ -97,11 +96,14 @@ def evaluate(landlord, landlord_up, landlord_front, landlord_down, eval_data, nu
|
||||||
ctx = mp.get_context('spawn')
|
ctx = mp.get_context('spawn')
|
||||||
q = ctx.SimpleQueue()
|
q = ctx.SimpleQueue()
|
||||||
processes = []
|
processes = []
|
||||||
|
|
||||||
|
players = load_card_play_models(card_play_model_path_dict)
|
||||||
|
|
||||||
for card_paly_data in card_play_data_list_each_worker:
|
for card_paly_data in card_play_data_list_each_worker:
|
||||||
|
|
||||||
p = ctx.Process(
|
p = ctx.Process(
|
||||||
target=mp_simulate,
|
target=mp_simulate,
|
||||||
args=(card_paly_data, card_play_model_path_dict, q, output, title))
|
args=(card_paly_data, players, q, output, title))
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
|
|
Loading…
Reference in New Issue