修复BUG
This commit is contained in:
parent
b5982b7195
commit
016d77aeb0
|
@ -385,6 +385,11 @@ model_dict_new['landlord'] = GeneralModel
|
|||
model_dict_new['landlord_up'] = GeneralModel
|
||||
model_dict_new['landlord_front'] = GeneralModel
|
||||
model_dict_new['landlord_down'] = GeneralModel
|
||||
model_dict_new_lite = {}
|
||||
model_dict_new_lite['landlord'] = GeneralModelLite
|
||||
model_dict_new_lite['landlord_up'] = GeneralModelLite
|
||||
model_dict_new_lite['landlord_front'] = GeneralModelLite
|
||||
model_dict_new_lite['landlord_down'] = GeneralModelLite
|
||||
|
||||
class OldModel:
|
||||
"""
|
||||
|
|
|
@ -6,11 +6,14 @@ from onnxruntime.datasets import get_example
|
|||
|
||||
from douzero.env.env import get_obs
|
||||
|
||||
def _load_model(position, model_path, model_type, use_legacy):
|
||||
from douzero.dmc.models import model_dict_new, model_dict, model_dict_legacy
|
||||
def _load_model(position, model_path, model_type, use_legacy, use_lite):
|
||||
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy
|
||||
model = None
|
||||
if model_type == "general":
|
||||
model = model_dict_new[position]()
|
||||
if use_lite:
|
||||
model = model_dict_new_lite[position]()
|
||||
else:
|
||||
model = model_dict_new[position]()
|
||||
else:
|
||||
if use_legacy:
|
||||
model = model_dict_legacy[position]()
|
||||
|
@ -51,7 +54,7 @@ class DeepAgent:
|
|||
self.use_legacy = True if "legacy" in model_path else False
|
||||
self.lite_model = True if "lite" in model_path else False
|
||||
self.model_type = "general" if "resnet" in model_path else "old"
|
||||
self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
|
||||
self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model)
|
||||
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
|
||||
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||
|
@ -62,10 +65,10 @@ class DeepAgent:
|
|||
|
||||
obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model)
|
||||
|
||||
z_batch = torch.from_numpy(obs['z_batch']).float()
|
||||
x_batch = torch.from_numpy(obs['x_batch']).float()
|
||||
if torch.cuda.is_available():
|
||||
z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
|
||||
# z_batch = torch.from_numpy(obs['z_batch']).float()
|
||||
# x_batch = torch.from_numpy(obs['x_batch']).float()
|
||||
# if torch.cuda.is_available():
|
||||
# z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
|
||||
# y_pred = self.model.forward(z_batch, x_batch)['values']
|
||||
# y_pred = y_pred.detach().cpu().numpy()
|
||||
y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0]
|
||||
|
|
Loading…
Reference in New Issue