vanilla压缩模型评估逻辑
This commit is contained in:
parent
55578c620e
commit
240cd8f0cc
|
@ -475,6 +475,11 @@ model_dict['landlord'] = LandlordLstmModel
|
|||
model_dict['landlord_up'] = FarmerLstmModel
|
||||
model_dict['landlord_front'] = FarmerLstmModel
|
||||
model_dict['landlord_down'] = FarmerLstmModel
|
||||
model_dict_lite = {}
|
||||
model_dict_lite['landlord'] = LandlordLstmModelLite
|
||||
model_dict_lite['landlord_up'] = FarmerLstmModelLite
|
||||
model_dict_lite['landlord_front'] = FarmerLstmModelLite
|
||||
model_dict_lite['landlord_down'] = FarmerLstmModelLite
|
||||
model_dict_legacy = {}
|
||||
model_dict_legacy['landlord'] = LandlordLstmModelLegacy
|
||||
model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy
|
||||
|
|
|
@ -7,7 +7,7 @@ from onnxruntime.datasets import get_example
|
|||
from douzero.env.env import get_obs
|
||||
|
||||
def _load_model(position, model_path, model_type, use_legacy, use_lite):
|
||||
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy
|
||||
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite
|
||||
model = None
|
||||
if model_type == "general":
|
||||
if use_lite:
|
||||
|
@ -18,7 +18,10 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite):
|
|||
if use_legacy:
|
||||
model = model_dict_legacy[position]()
|
||||
else:
|
||||
model = model_dict[position]()
|
||||
if use_lite:
|
||||
model = model_dict_lite[position]()
|
||||
else:
|
||||
model = model_dict[position]()
|
||||
model_state_dict = model.state_dict()
|
||||
if torch.cuda.is_available():
|
||||
pretrained = torch.load(model_path, map_location='cuda:0')
|
||||
|
|
Loading…
Reference in New Issue