修复BUG
This commit is contained in:
parent
2b607afc29
commit
c970a59655
|
@ -8,8 +8,8 @@ from douzero.env.env import get_obs
|
||||||
|
|
||||||
def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False):
|
def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False):
|
||||||
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite
|
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite
|
||||||
model_path = model_path + '.onnx'
|
model_path_onnx = model_path + '.onnx'
|
||||||
if use_onnx and os.path.exists(model_path):
|
if use_onnx and os.path.exists(model_path_onnx):
|
||||||
return None
|
return None
|
||||||
model = None
|
model = None
|
||||||
if model_type == "general":
|
if model_type == "general":
|
||||||
|
@ -30,14 +30,13 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx
|
||||||
pretrained = {k: v for k, v in pretrained.items() if k in model_state_dict}
|
pretrained = {k: v for k, v in pretrained.items() if k in model_state_dict}
|
||||||
model_state_dict.update(pretrained)
|
model_state_dict.update(pretrained)
|
||||||
model.load_state_dict(model_state_dict)
|
model.load_state_dict(model_state_dict)
|
||||||
# torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt"))
|
|
||||||
model.eval()
|
model.eval()
|
||||||
if use_onnx and not os.path.exists(model_path):
|
if use_onnx and not os.path.exists(model_path_onnx):
|
||||||
onnx_params = model.get_onnx_params(torch.device('cpu'))
|
onnx_params = model.get_onnx_params(torch.device('cpu'))
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
model,
|
model,
|
||||||
onnx_params['args'],
|
onnx_params['args'],
|
||||||
model_path,
|
model_path_onnx,
|
||||||
export_params=True,
|
export_params=True,
|
||||||
opset_version=10,
|
opset_version=10,
|
||||||
do_constant_folding=True,
|
do_constant_folding=True,
|
||||||
|
|
Loading…
Reference in New Issue