Source code for magnet.training.utils

import torch, pickle

import magnet as mag

[docs]def load_state(module, path, alternative_name=None): r"""Loads the state_dict of a PyTorch object from a specified path. This is a more robust version of the of the PyTorch way in the sense that the device mapping is automatically handled. Args: module (object): Any PyTorch object that has a state_dict path (pathlib.Path): The path to folder containing the state_dict file alternative_name (str or None): A fallback name for the file if the module object does not have a name attribute. Default: ``None`` Raises: RuntimeError: If no :attr:`alternative_name` is provided and the module does not have a name. .. note:: If you already know the file name, set :attr:`alternative_name` to that. This is just a convinience method that assumes that the file name will be the same as the name of the module (if there is one). """ name = alternative_name if not hasattr(module, 'name') else module.name if name is None: raise RuntimeError('Module Name is None!') filepath = path / (name + '.pt') device = 'cuda:0' if mag.device.type == 'cuda' else 'cpu' # Needed patch if filepath.exists(): module.load_state_dict(torch.load(filepath, map_location=device))
[docs]def save_state(module, path, alternative_name=None): r"""Saves the state_dict of a PyTorch object to a specified path. Args: module (object): Any PyTorch object that has a state_dict path (pathlib.Path): The path to a folder to save the state_dict to alternative_name (str or None): A fallback name for the file if the module object does not have a name attribute. Default: ``None`` Raises: RuntimeError: If no :attr:`alternative_name` is provided and the module does not have a name. """ name = alternative_name if not hasattr(module, 'name') else module.name if name is None: raise RuntimeError('Module Name is None!') path.mkdir(parents=True, exist_ok=True) filepath = path / (name + '.pt') torch.save(module.state_dict(), filepath)
[docs]def load_object(path, **kwargs): r"""A convinience method to unpickle a file. Args: path (pathlib.Path): The path to the pickle file Keyword Args: default (object): A default value to be returned if the file does not exist. Default: ``None`` Raises: RuntimeError: If a default keyword argument is not provided and the file is not found. """ if path.exists(): with open(path, 'rb') as f: return pickle.load(f) elif 'default' in kwargs.keys(): return kwargs['default'] else: raise RuntimeError(f'The path {path} does not exist. No default provided either.')
[docs]def save_object(obj, path): r"""A convinience method to pickle an object. Args: obj (object): The object to pickle path (pathlib.Path): The path to save to .. note:: If the :attr:`path` does not exists, it is created. """ path.parent.mkdir(parents=True, exist_ok=True) with open(path, 'wb') as f: pickle.dump(obj, f)