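# Douzero_Resnet/model_merger.py
#
# Loads the DouZero checkpoint at merger/model.tar, is intended to swap in
# standalone replacement weights from merger/resnet_<position>.ckpt, and
# writes the result to merger/model_merged.tar.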

import os.path
import torch
from douzero.dmc.file_writer import FileWriter
from douzero.dmc.models import Model, OldModel
from douzero.dmc.utils import get_batch, log, create_env, create_optimizers, act
# Device for the learner model that receives the checkpoint weights. Model takes
# the device as a string ("cpu" is used for the CPU elsewhere in this file);
# "0" presumably selects cuda:0 — confirm against Model's constructor.
_LEARNER_DEVICE = "0"
learner_model = Model(device=_LEARNER_DEVICE)
class myflags:
    """Minimal stand-in for the training command-line flags.

    Only the optimizer hyperparameters needed to build the optimizers with
    ``create_optimizers`` are provided (learning rate, momentum, alpha,
    epsilon). The optimizer state loaded from the checkpoint later in this
    script restores the saved hyperparameters, so these values only matter
    for constructing the optimizer objects.
    """

    learning_rate = 0.0003
    momentum = 0
    alpha = 0.99
    epsilon = 1e-5


# Only attribute access is needed, so the class itself serves as the namespace.
flags = myflags
# Source training checkpoint and the path the merged checkpoint is written to.
checkpointpath = "merger/model.tar"
merged_path = "merger/model_merged.tar"

# Fresh optimizers act only as containers for the optimizer state stored in
# the checkpoint; load_state_dict() below restores the saved state into them.
optimizers = create_optimizers(flags, learner_model)

# Load onto the CPU so loading does not depend on the GPU layout of the machine
# that saved the checkpoint; load_state_dict() copies tensors onto the devices
# of the learner model's parameters.
checkpoint_states = torch.load(checkpointpath, map_location="cpu")

print("Load original weights")
for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
    learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
    optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
stats = checkpoint_states["stats"]

print("Load replace weights")
for k in ['landlord']:
    replace_path = "merger/resnet_" + k + ".ckpt"
    if not os.path.exists(replace_path):
        print("No replacement weights at " + replace_path + ", keeping original " + k)
        continue
    weights = torch.load(replace_path, map_location="cpu")
    # Replacement weights must win: do NOT reload the checkpoint's weights for
    # this position afterwards, or the replacement is silently undone.
    learner_model.get_model(k).load_state_dict(weights)

frames = checkpoint_states["frames"]
position_frames = checkpoint_states["position_frames"]
log.info("Loaded checkpoint stats:\n%s", stats)
log.info('Saving merged checkpoint to %s', merged_path)
_models = learner_model.get_models()
torch.save({
    'model_state_dict': {k: _models[k].state_dict() for k in _models},
    'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
    "stats": stats,
    'flags': checkpoint_states["flags"],
    'frames': frames,
    'position_frames': position_frames,
}, merged_path)