Source code for magnet.training.train

from torch import optim
from contextlib import contextmanager

class Trainer:
    r"""Abstract base class for training models.

    The Trainer class makes it incredibly simple and convenient to train,
    monitor, debug and checkpoint entire Deep Learning projects.

    Simply define your training loop by implementing the :py:meth:`optimize` method.

    Args:
        models (list of :py:class:`nn.Module`): All the models that need to be trained
        optimizers (list of :py:class:`optim.Optimizer`): Any optimizers that are used

    .. note::
        If any model is in ``eval()`` mode, the trainer is *set off*.
        This means that, as per protocol, *all* models will not train.

    Attributes:
        callbacks (list): A list of callbacks attached to the trainer.

    Take a look at :py:class:`SupervisedTrainer` for an idea on how to extend this class.
    """
    def __init__(self, models, optimizers):
        self.models = models
        self.optimizers = optimizers

        self.parameters = set()
        self.register_parameter('iterations', 0)

    def optimize(self):
        r"""Defines the core optimization loop. This method is called on each iteration.

        Two quick protocols that one needs to follow are:

        1. **Do NOT** actually backpropagate or ``step()`` the optimizers
           if the trainer is not training. Use the :py:meth:`is_training` method
           to find out. This is essential since it ensures that the trainer
           behaves as expected when :py:meth:`is_training` is ``False``.
           Useful, for example, in cases like :py:class:`callbacks.ColdStart`.

        2. Send a callback the signal ``'gradient'`` with a keyword argument
           ``'models'`` that is the list of models that accumulate a gradient.
           Usually, it's all the models (``self.models``). Any callbacks that
           listen to this signal are interested in the gradient information
           (eg. ``callbacks.Babysitter``).
        """
        raise NotImplementedError

    def train(self, dataloader, epochs=1, callbacks=None, **kwargs):
        r"""Starts the training process.

        Args:
            dataloader (``DataLoader``): The MagNet dataloader that iterates
                over the training set
            epochs (float or int): The number of epochs to train for.
                Default: ``1``
            callbacks (list): Any callbacks to be attached. Default: ``None``

        Keyword Args:
            iterations (int): The number of iterations to train for.
                Overrides :attr:`epochs`.

        .. note::
            PyTorch ``DataLoader``\ s are not supported.
            Ideally, encapsulate your dataset in the ``Data`` class.
        """
        from magnet.training.callbacks import CallbackQueue

        self.dataloader = dataloader
        if callbacks is None:
            callbacks = []
        self.callbacks = CallbackQueue(callbacks)

        total_iterations = kwargs.get('iterations', int(epochs * len(dataloader)))

        self.callbacks('on_training_start', trainer=self, total_iterations=total_iterations)

        for self.iterations in range(self.iterations, self.iterations + total_iterations):
            next(self)

        self.callbacks('on_training_end', trainer=self)

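    # Usage sketch (not part of the class): assuming ``data`` is a MagNet ``Data``
    # object and ``trainer`` is a built trainer, a short run with a monitoring
    # callback might look like this:
    #
    #     from magnet.training import callbacks
    #     trainer.train(data(64, shuffle=True), epochs=2, callbacks=[callbacks.Monitor()])
    #
    # Passing ``iterations`` trains for a fixed number of batches instead:
    #
    #     trainer.train(data(64, shuffle=True), iterations=500)
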
    def __iter__(self):
        return self

    def __next__(self):
        self.callbacks('on_batch_start', trainer=self)
        self.optimize()
        self.callbacks('on_batch_end', trainer=self)

    @contextmanager
    def mock(self, path=None):
        r"""A context manager that creates a temporary *'safe'* scope for training.

        All impact on stateful objects (models, optimizers and the trainer
        itself) is forgotten once out of this scope.

        This is very useful if you need to try out *what-if experiments*.

        Args:
            path (pathlib.Path): The path to save temporary states into.
                Default: ``{System temp directory}/.mock_trainer``
        """
        from shutil import rmtree

        if path is None:
            from pathlib import Path
            from tempfile import gettempdir
            path = Path(gettempdir()) / '.mock_trainer'

        rmtree(path, ignore_errors=True)  # Remove any existing directory
        self.save_state(path)

        try:
            yield
        finally:
            self.load_state(path)
            rmtree(path)

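    # Usage sketch (assumes ``trainer`` and ``data`` already exist): any training
    # done inside the ``mock()`` scope is rolled back on exit, so what-if runs
    # leave no trace:
    #
    #     with trainer.mock():
    #         trainer.train(data(64, shuffle=True), epochs=0.1)
    #     # Models, optimizers and the iteration count are back to their
    #     # pre-mock state here.
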
    def epochs(self, mode=None):
        r"""The number of epochs completed.

        Args:
            mode (str or None): If the mode is ``'start'`` or ``'end'``,
                a boolean is returned signalling if it's the start or
                end of an epoch
        """
        if mode is None:
            return self.iterations / len(self.dataloader)

        if mode == 'start':
            return (self.iterations / len(self.dataloader)).is_integer()

        if mode == 'end':
            return ((self.iterations + 1) / len(self.dataloader)).is_integer()

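    # Usage sketch (values assume ``train()`` has attached a dataloader):
    #
    #     trainer.epochs()         # e.g. 1.5 -> one and a half epochs completed
    #     trainer.epochs('start')  # True only on the first iteration of an epoch
    #     trainer.epochs('end')    # True only on the last iteration of an epoch
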
    def is_training(self):
        return all(model.training for model in self.models)

    def load_state(self, path):
        from magnet.training.utils import load_state, load_object

        for i, model in enumerate(self.models):
            load_state(model, path / 'models', alternative_name=str(i))
        for i, optimizer in enumerate(self.optimizers):
            load_state(optimizer, path / 'optimizers', alternative_name=str(i))

        state_dict = load_object(path / 'state.p', default={})
        for attr, val in state_dict.items():
            self.register_parameter(attr, val)

        try:
            self.callbacks('load_state', trainer=self, path=path / 'callbacks')
        except AttributeError:
            pass

        try:
            self.dataloader.load_state_dict(path / 'dataloader.p')
        except AttributeError:
            pass

    def save_state(self, path):
        from magnet.training.utils import save_state, save_object

        for i, model in enumerate(self.models):
            save_state(model, path / 'models', alternative_name=str(i))
        for i, optimizer in enumerate(self.optimizers):
            save_state(optimizer, path / 'optimizers', alternative_name=str(i))

        state_dict = {attr: getattr(self, attr) for attr in self.parameters}
        save_object(state_dict, path / 'state.p')

        try:
            self.callbacks('save_state', trainer=self, path=path / 'callbacks')
        except AttributeError:
            pass

        try:
            self.dataloader.save_state_dict(path / 'dataloader.p')
        except AttributeError:
            pass

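    # Checkpointing sketch (``checkpoint_path`` is a hypothetical ``pathlib.Path``):
    #
    #     trainer.save_state(checkpoint_path / 'my-trainer')  # models, optimizers,
    #                                                         # registered parameters, callbacks
    #     ...
    #     trainer.load_state(checkpoint_path / 'my-trainer')  # resume from that point
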
    def register_parameter(self, name, value):
        r"""Use this to register *'stateful'* parameters that are serialized
        along with the trainer's state.
        """
        setattr(self, name, value)
        self.parameters.add(name)

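# The class below is an illustrative sketch, not part of the library: a minimal
# way to extend Trainer. The name ``ExampleTrainer``, its single model/optimizer
# pair and its ``loss_fn`` argument are assumptions made for demonstration only.
# It follows the two protocols described in ``optimize()`` and registers an
# extra stateful parameter so that it survives ``save_state()``/``load_state()``.
class ExampleTrainer(Trainer):
    def __init__(self, model, optimizer, loss_fn):
        super().__init__([model], [optimizer])
        self.loss_fn = loss_fn
        self.register_parameter('best_loss', float('inf'))  # serialized with the trainer

    def optimize(self):
        model, optimizer = self.models[0], self.optimizers[0]

        x, y = next(self.dataloader)
        loss = self.loss_fn(model(x), y)

        if self.is_training():  # Protocol 1: only backprop and step when training
            loss.backward()
            # Protocol 2: broadcast the models that hold gradients
            self.callbacks('gradient', trainer=self, models=self.models)
            optimizer.step()
            optimizer.zero_grad()

        self.best_loss = min(self.best_loss, loss.item())
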
class SupervisedTrainer(Trainer):
    r"""A simple trainer that implements a supervised approach where a model
    :math:`\hat{y} = f(x)` is trained to match its prediction :math:`\hat{y}`
    to the ground-truth :math:`y` according to some specified loss.

    This is the training routine that most high-level deep learning frameworks implement.

    Args:
        model (``nn.Module``): The model that needs to be trained
        optimizer (str or ``optim.Optimizer``): The optimizer used to train
            the model. Default: ``'adam'``
        loss (str or ``callable``): A loss function that gives the objective
            to be minimized. Default: ``'cross_entropy'``
        metrics (list): Any other metrics that need to be monitored.
            Default: ``None``

    * :attr:`optimizer` can be an actual ``optim.Optimizer`` instance or the
      name of a popular optimizer (eg. ``'adam'``).

    * :attr:`loss` can be a function or the name of a popular loss function
      (eg. ``'cross_entropy'``). It should accept 2 arguments
      (:math:`\hat{y}`, :math:`y`).

    * :attr:`metrics` should contain a list of functions which accept
      2 arguments (:math:`\hat{y}`, :math:`y`), like the loss function.

    .. note::
        A static :py:meth:`validate` function is provided for the validation callback.

    .. note::
        The :attr:`metrics` are of no use unless there is some callback
        (eg. ``callbacks.Monitor``) to receive them.

    Examples::

        >>> import magnet as mag
        >>> import magnet.nodes as mn

        >>> from magnet.data import Data
        >>> from magnet.training import callbacks, SupervisedTrainer

        >>> data = Data.get('mnist')

        >>> model = mn.Linear(10, act=None)
        >>> model.build(x=next(data())[0])

        >>> trainer = SupervisedTrainer(model)
        >>> callbacks = [callbacks.Monitor(),
        ...              callbacks.Validate(data(64, mode='val'), SupervisedTrainer.validate)]
        >>> trainer.train(data(64, shuffle=True), 1, callbacks)
    """
    def __init__(self, model, optimizer='adam', loss='cross_entropy', metrics=None):
        from magnet.nodes.functional import wiki

        if isinstance(optimizer, str):
            optimizer = optimizer_wiki[optimizer.lower()](model.parameters())
        if isinstance(loss, str):
            loss = wiki['losses'][loss.lower()]

        if metrics is None:
            metrics = []
        if not isinstance(metrics, (tuple, list)):
            metrics = [metrics]
        for i, metric in enumerate(metrics):
            if isinstance(metric, str):
                metrics[i] = (metric, wiki['metrics'][metric.lower()])

        super().__init__([model], [optimizer])
        self.loss = loss
        self.metrics = metrics

    def optimize(self):
        optimizer = self.optimizers[0]

        loss = self.get_loss(self.dataloader)

        # Protocol 1: Backprop and step() only if the trainer is training
        if self.is_training():
            loss.backward()
            # Protocol 2: Broadcast the models that accumulate the gradient
            # using the signal 'gradient' before clearing them.
            self.callbacks('gradient', trainer=self, models=self.models)
            optimizer.step()
            optimizer.zero_grad()

    @staticmethod
    def validate(trainer, dataloader):
        r"""Static helper method to validate the models in :attr:`trainer`
        against the data in :attr:`dataloader`.

        Can be passed to ``callbacks.Validate()``.
        """
        trainer.get_loss(dataloader, validation=True)

    def get_loss(self, dataloader, validation=False):
        r"""Utility function that returns the loss and broadcasts metrics."""
        def write_stats(key, value):
            self.callbacks('write_stats', trainer=self, key=key, value=value,
                           validation=validation, buffer_size=len(dataloader))

        model = self.models[0]

        x, y = next(dataloader)
        y_pred = model(x)

        loss = self.loss(y_pred, y)

        # Broadcast the loss and any other metrics using the 'write_stats' signal.
        write_stats('loss', loss.item())
        for metric in self.metrics:
            write_stats(metric[0], metric[1](y_pred, y).item())

        return loss

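# Usage sketch (illustrative only): SupervisedTrainer also accepts callables
# directly. Assuming ``model`` and ``data`` are set up as in the docstring example
# above, a custom loss and an extra monitored metric could be passed like this;
# ``accuracy`` is a hypothetical helper defined here for the example:
#
#     import torch.nn.functional as F
#
#     def accuracy(y_pred, y):
#         return (y_pred.argmax(dim=-1) == y).float().mean()
#
#     trainer = SupervisedTrainer(model, optimizer='adam', loss=F.cross_entropy,
#                                 metrics=[('accuracy', accuracy)])
#     trainer.train(data(64, shuffle=True), epochs=1)
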
def finish_training(path, names=None):
    r"""A helper function for cleaning up the training logs and other checkpoints
    and retaining only the state_dicts of the trained models.

    Args:
        path (pathlib.Path): The path where the trainer was checkpointed
        names (list): The names of the models in the order given to the trainer.
            Default: ``None``

    * :attr:`names` can be used if the models themselves did not have names
      prior to training. The checkpoints default to an ordered naming scheme.
      If passed, the files are additionally renamed to these names.

    .. note::
        Does nothing / fails silently if the path does not exist.

    Example::

        >>> # Assume that we've defined two models - encoder and decoder,
        >>> # and a suitable trainer. The models do not have a 'name' attribute.

        >>> trainer.save_state(checkpoint_path / 'my-trainer')

        >>> # Suppose the checkpoint directory contains the following files:
        >>> # my-trainer/
        >>> #     models/
        >>> #         0.pt
        >>> #         1.pt
        >>> #     callbacks/
        >>> #         monitor/
        >>> #         babysitter/
        >>> #     state.p

        >>> finish_training(path, names=['encoder', 'decoder'])

        >>> # Now the directory contains these files:
        >>> # encoder.pt
        >>> # decoder.pt
    """
    if not path.exists():
        return

    import shutil

    if isinstance(names, str):
        names = [names]

    filenames = list((path / 'models').glob('*.pt'))
    if names is None:
        names = [filename.stem for filename in filenames]

    for name, filename in zip(names, filenames):
        shutil.move(filename, path.parent / (name + '.pt'))

    shutil.rmtree(path)

optimizer_wiki = {'adam': optim.Adam}
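
# ``optimizer_wiki`` maps the optimizer names accepted by SupervisedTrainer to
# constructors that are called with ``model.parameters()``. A sketch of how an
# extra entry could be registered (an assumption, not something the module ships):
#
#     optimizer_wiki['sgd'] = lambda params: optim.SGD(params, lr=0.01)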