vanilla压缩模型评估逻辑

This commit is contained in:
zhiyang7 2021-12-24 09:49:36 +08:00
parent 55578c620e
commit 240cd8f0cc
2 changed files with 10 additions and 2 deletions

View File

@ -475,6 +475,11 @@ model_dict['landlord'] = LandlordLstmModel
model_dict['landlord_up'] = FarmerLstmModel
model_dict['landlord_front'] = FarmerLstmModel
model_dict['landlord_down'] = FarmerLstmModel
model_dict_lite = {}
model_dict_lite['landlord'] = LandlordLstmModelLite
model_dict_lite['landlord_up'] = FarmerLstmModelLite
model_dict_lite['landlord_front'] = FarmerLstmModelLite
model_dict_lite['landlord_down'] = FarmerLstmModelLite
model_dict_legacy = {}
model_dict_legacy['landlord'] = LandlordLstmModelLegacy
model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy

View File

@ -7,7 +7,7 @@ from onnxruntime.datasets import get_example
from douzero.env.env import get_obs
def _load_model(position, model_path, model_type, use_legacy, use_lite):
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite
model = None
if model_type == "general":
if use_lite:
@ -18,7 +18,10 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite):
if use_legacy:
model = model_dict_legacy[position]()
else:
model = model_dict[position]()
if use_lite:
model = model_dict_lite[position]()
else:
model = model_dict[position]()
model_state_dict = model.state_dict()
if torch.cuda.is_available():
pretrained = torch.load(model_path, map_location='cuda:0')