| | """Tools to save/restore model from checkpoints.""" |
| |
|
| | import argparse |
| | import shutil |
| | import sys |
| | import os |
| | import re |
| | import json |
| | import time |
| |
|
| | import torch |
| |
|
# Matches per-step checkpoint file names such as "model_checkpoint-00001000";
# group 1 captures the step digits. Written as a raw string because "\d" in a
# plain string literal is an invalid escape sequence (SyntaxWarning on 3.12+).
CHECKPOINT_PATTERN = re.compile(r'^model_checkpoint-(\d+)$')
| |
|
| |
|
class ArgsDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    The instance's attribute namespace *is* the dictionary itself, so
    ``d.x`` and ``d["x"]`` always refer to the same entry.
    """

    def __init__(self, **kwargs):
        super().__init__()
        for name in kwargs:
            self[name] = kwargs[name]
        # Alias the attribute namespace to the mapping so attribute access
        # and item access share storage.
        self.__dict__ = self
| |
|
| |
|
def load_checkpoint(item_dict, model_dir, map_location=None, step=None):
    """Restore the state of every object in ``item_dict`` from a checkpoint.

    Args:
        item_dict: mapping from checkpoint key to an object exposing
            ``state_dict()`` / ``load_state_dict()``, e.g.
            ``{"model": model, "opt1": opt1, ...}``. The "model" entry is
            optional.
        model_dir: directory holding the checkpoint files.
        map_location: forwarded unchanged to ``torch.load``.
        step: when given, load ``model_checkpoint-<step as 8 digits>``;
            otherwise load ``model_checkpoint`` (the file ``save_checkpoint``
            points at the most recent save).

    Returns:
        The ``"step"`` stored in the checkpoint (0 if it has none), or 0 when
        the checkpoint file does not exist.

    Raises:
        KeyError: if a key of ``item_dict`` is missing from the checkpoint.
    """
    path = os.path.join(model_dir, 'model_checkpoint')
    if step is not None:
        path += '-{:08d}'.format(step)
    if not os.path.exists(path):
        return 0

    print("Loading model from %s" % path)
    checkpoint = torch.load(path, map_location=map_location)

    # Only touch the model entry when a model is actually being restored;
    # previously restoring e.g. just the optimizer raised KeyError here.
    if "model" in item_dict:
        saved_model_state = checkpoint['model']
        # Entries the checkpoint lacks keep the model's current values, so the
        # strict load_state_dict below does not fail on missing keys.
        for key, value in item_dict["model"].state_dict().items():
            if key not in saved_model_state:
                saved_model_state[key] = value

    for item_name, item in item_dict.items():
        item.load_state_dict(checkpoint[item_name])
    return checkpoint.get('step', 0)
| |
|
| |
|
def load_and_map_checkpoint(model, model_dir, remap, map_location=None):
    """Initialize selected entries of ``model`` from another model's checkpoint.

    Args:
        model: object exposing ``state_dict()`` / ``load_state_dict()``.
        model_dir: directory whose ``model_checkpoint`` file is read.
        remap: dict mapping a key of ``model``'s state dict to the key in the
            checkpoint's ``"model"`` state dict whose value should be copied.
        map_location: forwarded to ``torch.load`` (mirrors
            ``load_checkpoint``), e.g. ``"cpu"`` to load a GPU-saved
            checkpoint on a CPU-only host.

    Raises:
        KeyError: if a remapped source key is absent from the checkpoint.
    """
    path = os.path.join(model_dir, 'model_checkpoint')
    print("Loading parameters %s from %s" % (remap.keys(), model_dir))
    checkpoint = torch.load(path, map_location=map_location)
    source_state = checkpoint['model']
    new_state_dict = model.state_dict()
    for name, source_name in remap.items():
        new_state_dict[name] = source_state[source_name]
    model.load_state_dict(new_state_dict)
| |
|
| |
|
def save_checkpoint(items, step, model_dir, ignore=(),
                    keep_every_n=10000000):
    """Save every item's state and point ``model_checkpoint`` at the new file.

    Args:
        items: mapping from checkpoint key to an object exposing
            ``state_dict()``; must contain "model" when ``ignore`` is used.
        step: training step; stored under ``"step"`` and used in the file name
            ``model_checkpoint-<step as 8 digits>``.
        model_dir: directory to write into; created if missing.
        ignore: prefix string, or iterable of prefix strings. Entries of the
            "model" state dict whose key starts with any of them are not saved.
        keep_every_n: if not None, older step checkpoints in ``model_dir``
            (never the one just written) are scanned in ascending step order
            and each is deleted unless it is at least ``keep_every_n`` steps
            after the previously kept one.
    """
    os.makedirs(model_dir, exist_ok=True)
    path_without_step = os.path.join(model_dir, 'model_checkpoint')
    path_with_step = '{}-{:08d}'.format(path_without_step, step)

    saved_dic = {key: item.state_dict() for key, item in items.items()}
    prefixes = (ignore,) if isinstance(ignore, str) else tuple(ignore)
    if prefixes:
        model_state = saved_dic["model"]
        # Collect first, then delete: mutating a dict while iterating it
        # raises RuntimeError. Deleting in place keeps the state dict's type
        # and metadata intact for load_state_dict.
        for key in [k for k in model_state if k.startswith(prefixes)]:
            del model_state[key]
    torch.save({**saved_dic, "step": step}, path_with_step)

    _point_latest_to(path_with_step, path_without_step)

    if keep_every_n is not None:
        _prune_checkpoints(model_dir, keep_every_n,
                           os.path.basename(path_with_step))


def _point_latest_to(path_with_step, path_without_step):
    """Make ``path_without_step`` refer to ``path_with_step`` (symlink, else copy)."""
    try:
        os.unlink(path_without_step)
    except FileNotFoundError:
        pass
    try:
        # Relative target keeps the link valid if the directory is moved.
        os.symlink(os.path.basename(path_with_step), path_without_step)
    except OSError:
        # Platforms/filesystems without symlink support get a full copy.
        shutil.copy2(path_with_step, path_without_step)


def _prune_checkpoints(model_dir, keep_every_n, current_name):
    """Delete step checkpoints, other than ``current_name``, spaced closer than ``keep_every_n``."""
    all_checkpoints = []
    for name in os.listdir(model_dir):
        m = CHECKPOINT_PATTERN.match(name)
        if m is None or name == current_name:
            continue
        all_checkpoints.append((int(m.group(1)), name))
    all_checkpoints.sort()

    last_kept = float('-inf')
    for checkpoint_step, name in all_checkpoints:
        if checkpoint_step - last_kept >= keep_every_n:
            last_kept = checkpoint_step
            continue
        os.unlink(os.path.join(model_dir, name))
| |
|
| |
|
class Saver(object):
    """Class to manage save and restore for the model and optimizer."""

    def __init__(self, items, keep_every_n=None):
        """
        Args:
            items: dict mapping checkpoint key to an object exposing
                ``state_dict()`` / ``load_state_dict()``; must contain "model".
            keep_every_n: forwarded to ``save_checkpoint``; None disables
                deletion of older step checkpoints.
        """
        assert isinstance(items, dict)
        assert "model" in items
        self._items = items
        self._keep_every_n = keep_every_n

    def restore(self, model_dir, map_location=None,
                step=None, item_keys=("model", "optimizer")):
        """Restores model and optimizer from given directory.

        Args:
            model_dir: directory to load from.
            map_location: forwarded to ``torch.load``.
            step: specific step to load; None loads the latest checkpoint.
            item_keys: which managed items should be restored. Keys this
                saver does not manage (e.g. "optimizer" when only a model was
                registered) are skipped instead of raising KeyError.

        Returns:
            Last training step for the model restored.
        """
        items2restore = {k: self._items[k] for k in item_keys if k in self._items}
        last_step = load_checkpoint(
            items2restore, model_dir, map_location, step)
        return last_step

    def save(self, model_dir, step):
        """Saves model and optimizer to given directory.

        Args:
            model_dir: Model directory to save.
            step: Current training step.
        """
        save_checkpoint(self._items, step, model_dir,
                        keep_every_n=self._keep_every_n)

    def restore_part(self, other_model_dir, remap):
        """Restores part of the model from other directory.

        Useful to initialize part of the model with another pretrained model.

        Args:
            other_model_dir: Model directory to load from.
            remap: dict, remapping current parameters to the other model's.
        """
        load_and_map_checkpoint(self._items["model"], other_model_dir, remap)
| |
|