"""Merge replacement network weights into a DouZero training checkpoint.

Steps:
1. Load the checkpoint at ``merger/model.tar``.
2. Restore model and optimizer state for every position from it.
3. For each position in ``REPLACE_POSITIONS``, overwrite the model weights
   with the state dict stored in ``merger/resnet_<position>.ckpt``. Positions
   whose file is missing are skipped with a warning.
4. Write the result to ``merger/model_merged.tar`` using the same checkpoint
   layout as the input.
"""
import os.path

import torch

from douzero.dmc.file_writer import FileWriter
from douzero.dmc.models import Model, OldModel
from douzero.dmc.utils import get_batch, log, create_env, create_optimizers, act

# Positions whose model and optimizer state are restored from the checkpoint.
POSITIONS = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
# Positions whose model weights are replaced when a replacement file exists.
REPLACE_POSITIONS = ['landlord']

merger_dir = "merger"
checkpointpath = os.path.join(merger_dir, "model.tar")
merged_path = os.path.join(merger_dir, "model_merged.tar")

learner_model = Model(device="0")


class myflags:
    """Minimal stand-in for the training flags read by ``create_optimizers``.

    NOTE(review): presumably mapped to lr / momentum / alpha / eps of the
    optimizer. The optimizer state loaded from the checkpoint below also
    restores its param-group hyper-parameters, so these values mainly matter
    for constructing the optimizers.
    """
    learning_rate = 0.0003
    momentum = 0
    alpha = 0.99
    epsilon = 1e-5


flags = myflags
optimizers = create_optimizers(flags, learner_model)

checkpoint_states = torch.load(checkpointpath)

print("Load original weights")
for k in POSITIONS:
    learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
    optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
stats = checkpoint_states["stats"]

print("Load replace weights")
for k in REPLACE_POSITIONS:
    replace_path = os.path.join(merger_dir, "resnet_" + k + ".ckpt")
    if not os.path.exists(replace_path):
        log.warning('No replacement weights for %s at %s; keeping checkpoint weights', k, replace_path)
        continue
    # Load onto CPU: load_state_dict copies the tensors into the model's own
    # parameters, so they need not be materialized on the GPU first.
    weights = torch.load(replace_path, map_location="cpu")
    # Do NOT reload the checkpoint weights afterwards. Doing so would discard
    # the replacement and turn the merge into a no-op.
    learner_model.get_model(k).load_state_dict(weights)

frames = checkpoint_states["frames"]
position_frames = checkpoint_states["position_frames"]
log.info(f"Loaded checkpoint stats:\n{stats}")

log.info('Saving merged checkpoint to %s', merged_path)
_models = learner_model.get_models()
torch.save({
    'model_state_dict': {k: _models[k].state_dict() for k in _models},
    'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
    "stats": stats,
    'flags': checkpoint_states["flags"],
    'frames': frames,
    'position_frames': position_frames,
}, merged_path)