优化资源占用
This commit is contained in:
parent
8ac4684607
commit
b2dd98cc30
|
@ -8,6 +8,9 @@ from douzero.env.env import get_obs
|
|||
|
||||
def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False):
|
||||
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite
|
||||
model_path = model_path + '.onnx'
|
||||
if use_onnx and os.path.exists(model_path):
|
||||
return None
|
||||
model = None
|
||||
if model_type == "general":
|
||||
if use_lite:
|
||||
|
@ -29,21 +32,19 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx
|
|||
model.load_state_dict(model_state_dict)
|
||||
# torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt"))
|
||||
model.eval()
|
||||
model_path = model_path + '.onnx'
|
||||
if use_onnx:
|
||||
if not os.path.exists(model_path):
|
||||
onnx_params = model.get_onnx_params(torch.device('cpu'))
|
||||
torch.onnx.export(
|
||||
model,
|
||||
onnx_params['args'],
|
||||
model_path,
|
||||
export_params=True,
|
||||
opset_version=10,
|
||||
do_constant_folding=True,
|
||||
input_names=onnx_params['input_names'],
|
||||
output_names=onnx_params['output_names'],
|
||||
dynamic_axes=onnx_params['dynamic_axes']
|
||||
)
|
||||
if use_onnx and not os.path.exists(model_path):
|
||||
onnx_params = model.get_onnx_params(torch.device('cpu'))
|
||||
torch.onnx.export(
|
||||
model,
|
||||
onnx_params['args'],
|
||||
model_path,
|
||||
export_params=True,
|
||||
opset_version=10,
|
||||
do_constant_folding=True,
|
||||
input_names=onnx_params['input_names'],
|
||||
output_names=onnx_params['output_names'],
|
||||
dynamic_axes=onnx_params['dynamic_axes']
|
||||
)
|
||||
return None
|
||||
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue