import os.path

import torch
from douzero.dmc.models import Model
from douzero.dmc.utils import log, create_optimizers
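
# Merge standalone ResNet weights for selected positions into an existing
# DouZero checkpoint and save the result as a new, merged checkpoint.

# Learner model that receives the merged weights; device "0" selects the
# first CUDA device.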
learner_model = Model(device="0")

# Hyperparameters read by create_optimizers (the RMSprop-style arguments):
#     lr=flags.learning_rate,
#     momentum=flags.momentum,
#     eps=flags.epsilon,
#     alpha=flags.alpha)
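# Minimal stand-in for the command-line flags parsed during training.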
class myflags:
    learning_rate = 0.0003
    momentum = 0
    alpha = 0.99
    epsilon = 1e-5

flags = myflags
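
# Paths: the original checkpoint to merge into, and where to write the result.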
checkpointpath = "merger/model.tar"
merged_path = "merger/model_merged.tar"
optimizers = create_optimizers(flags, learner_model)
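
# Actor-side models on CPU, kept from the training setup; they are not used
# by the merge itself.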
models = {}
device_iterator = ["cpu"]
for device in device_iterator:
    model = Model(device="cpu")
    model.share_memory()
    model.eval()
    models[device] = model
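
# torch.load restores tensors to the devices recorded in the checkpoint;
# pass map_location here if they should be relocated.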
checkpoint_states = torch.load(checkpointpath)
print("Load original weights")
for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:  # three-player variant: ['landlord', 'landlord_up', 'landlord_down']
    learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
    optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])

stats = checkpoint_states["stats"]
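
# Overwrite selected positions with externally trained ResNet weights when the
# corresponding checkpoint file exists.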
print("Load replace weights")
for k in ['landlord']:
    if not os.path.exists("merger/resnet_" + k + ".ckpt"):
        continue
    weights = torch.load("merger/resnet_" + k + ".ckpt", map_location="cuda:0")
    learner_model.get_model(k).load_state_dict(weights)
    # Do not reload the original weights here; that would undo the replacement:
    # learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
frames = checkpoint_states["frames"]
# frames = 3085177900
position_frames = checkpoint_states["position_frames"]
log.info(f"Loaded checkpoint stats:\n{stats}")
log.info('Saving merged checkpoint to %s', merged_path)
_models = learner_model.get_models()
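
# Save in the layout the DouZero trainer expects, reusing stats, flags, and
# frame counters from the original checkpoint.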
torch.save({
    'model_state_dict': {k: _models[k].state_dict() for k in _models},  # or: {"general": _models["landlord"].state_dict()}
    'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},  # or: {"general": optimizers["landlord"].state_dict()}
    'stats': stats,
    'flags': checkpoint_states["flags"],
    'frames': frames,
    'position_frames': position_frames
}, merged_path)
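
# Quick sanity check (hypothetical usage): reload the merged file and list the
# positions it contains, e.g.
#     merged = torch.load(merged_path, map_location="cpu")
#     print(merged["model_state_dict"].keys())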