优化资源占用

This commit is contained in:
zhiyang7 2021-12-29 16:56:42 +08:00
parent 8ac4684607
commit b2dd98cc30
1 changed files with 16 additions and 15 deletions

View File

@ -8,6 +8,9 @@ from douzero.env.env import get_obs
def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False): def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False):
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite
model_path = model_path + '.onnx'
if use_onnx and os.path.exists(model_path):
return None
model = None model = None
if model_type == "general": if model_type == "general":
if use_lite: if use_lite:
@ -29,21 +32,19 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx
model.load_state_dict(model_state_dict) model.load_state_dict(model_state_dict)
# torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt")) # torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt"))
model.eval() model.eval()
model_path = model_path + '.onnx' if use_onnx and not os.path.exists(model_path):
if use_onnx: onnx_params = model.get_onnx_params(torch.device('cpu'))
if not os.path.exists(model_path): torch.onnx.export(
onnx_params = model.get_onnx_params(torch.device('cpu')) model,
torch.onnx.export( onnx_params['args'],
model, model_path,
onnx_params['args'], export_params=True,
model_path, opset_version=10,
export_params=True, do_constant_folding=True,
opset_version=10, input_names=onnx_params['input_names'],
do_constant_folding=True, output_names=onnx_params['output_names'],
input_names=onnx_params['input_names'], dynamic_axes=onnx_params['dynamic_axes']
output_names=onnx_params['output_names'], )
dynamic_axes=onnx_params['dynamic_axes']
)
return None return None
return model return model