magnet.training

Trainer

class magnet.training.Trainer(models, optimizers)

Abstract base class for training models.

The Trainer class makes it simple and convenient to train, monitor, debug and checkpoint entire deep learning projects.

Simply define your training loop by implementing the optimize() method.

Parameters:
  • models (list of nn.Module) – All the models that need to be trained
  • optimizers (list of optim.Optimizer) – Any optimizers that are used

Note

If any model is in eval() mode, the trainer is switched off. By protocol, this means that none of the models will be trained.

Variables:
  • callbacks (list) – A list of callbacks attached to the trainer.

Take a look at SupervisedTrainer for an idea on how to extend this class.

optimize()

Defines the core optimization loop. This method is called on each iteration.

Two protocols must be followed:

1. Do NOT backpropagate or step() the optimizers if the trainer is not training; use the is_training() method to check. This ensures that the trainer behaves as expected when is_training() is False, which callbacks such as callbacks.ColdStart rely on.

2. Broadcast the signal 'gradient' to the callbacks with a keyword argument 'models', the list of models that accumulate a gradient. Usually, this is all the modules (self.modules).

Callbacks that listen for this signal are interested in gradient information (e.g. callbacks.Babysitter).
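
The two protocols above can be illustrated without MagNet or PyTorch. The sketch below is a hypothetical stand-in: `ToyTrainer`, its toy loss and the `events` list are illustrative names, not MagNet's API, and the `loss.backward()` / `optimizer.step()` calls are replaced by a counter so that the control flow stays visible.

```python
class ToyTrainer:
    """Hypothetical stand-in for a Trainer subclass; not MagNet's implementation."""

    def __init__(self, models, callbacks):
        self.models = models          # stands in for the list of nn.Module objects
        self.callbacks = callbacks    # any callable accepting (signal, **payload)
        self.training = True          # a ColdStart-like callback could flip this
        self.optimizer_steps = 0      # counts what would be optimizer.step() calls

    def is_training(self):
        return self.training

    def optimize(self, x, y):
        loss = (x - y) ** 2           # toy squared-error loss on one scalar datapoint

        # Protocol 1: never backpropagate or step the optimizers unless training.
        if not self.is_training():
            return loss

        # A real subclass would call loss.backward() and optimizer.step() here.
        self.optimizer_steps += 1

        # Protocol 2: tell gradient-aware callbacks which models hold gradients.
        # (In this sketch the signal is only sent after an actual update.)
        self.callbacks('gradient', models=self.models)
        return loss


events = []
trainer = ToyTrainer(['model'], callbacks=lambda signal, **payload: events.append((signal, payload)))

trainer.optimize(3.0, 1.0)     # training: one update and one 'gradient' broadcast
trainer.training = False
trainer.optimize(3.0, 1.0)     # not training: loss computed, nothing else happens
print(trainer.optimizer_steps, events)
```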

train(dataloader, epochs=1, callbacks=None, **kwargs)

Starts the training process.

Parameters:
  • dataloader (DataLoader) – The MagNet dataloader that iterates over the training set
  • epochs (float or int) – The number of epochs to train for. Default: 1
  • callbacks (list) – Any callbacks to be attached. Default: None
Keyword Arguments:
  • iterations (int) – The number of iterations to train for. Overrides epochs.

Note

PyTorch DataLoaders are not supported.

Ideally, encapsulate your dataset in the Data class.
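
Since epochs may be fractional and iterations overrides it, a run has to be converted into a number of optimization steps. A minimal sketch, assuming the count is epochs times the number of batches per epoch truncated to an integer (the exact rounding MagNet uses is not documented here); `total_iterations` is a hypothetical helper:

```python
def total_iterations(batches_per_epoch, epochs=1, iterations=None):
    """Hypothetical helper: how many optimize() calls one train() run performs."""
    if iterations is not None:        # an explicit iteration count overrides epochs
        return int(iterations)
    return int(epochs * batches_per_epoch)


# A training set of 60 000 examples with a batch size of 64 gives 938 batches per epoch.
batches = -(-60000 // 64)             # ceiling division
print(total_iterations(batches, epochs=0.5))                 # 469
print(total_iterations(batches, epochs=3, iterations=100))   # 100
```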

mock(path=None)

A context manager that creates a temporary ‘safe’ scope for training.

All changes to stateful objects (models, optimizers and the trainer itself) are discarded once execution leaves this scope.

This is useful for trying out what-if experiments.

Parameters:
  • path (pathlib.Path) – The path to save temporary states into. Default: {System temp directory}/.mock_trainer
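
The what-if behaviour of mock() can be approximated with the standard library. This sketch makes a deliberate swap: instead of writing temporary states to path, it keeps an in-memory `copy.deepcopy` snapshot of each object's attributes and restores them on exit. `mock_state` and `ToyState` are hypothetical names.

```python
import copy
from contextlib import contextmanager


@contextmanager
def mock_state(*objects):
    """Hypothetical in-memory analogue of Trainer.mock(): changes are undone on exit."""
    snapshots = [copy.deepcopy(vars(obj)) for obj in objects]
    try:
        yield
    finally:
        # Restore every object, even if the experiment raised an exception.
        for obj, snapshot in zip(objects, snapshots):
            vars(obj).clear()
            vars(obj).update(snapshot)


class ToyState:
    """Hypothetical stateful object standing in for a model, optimizer or trainer."""

    def __init__(self):
        self.iterations = 0


trainer = ToyState()
with mock_state(trainer):
    trainer.iterations += 100      # a what-if run of 100 iterations
    print(trainer.iterations)      # 100
print(trainer.iterations)          # 0
```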

epochs(mode=None)

The number of epochs completed.

Parameters:
  • mode (str or None) – If mode is 'start' or 'end', a boolean is returned signalling whether it is the start or end of an epoch.

register_parameter(name, value)

Use this to register ‘stateful’ parameters so that they are serialized.

SupervisedTrainer

class magnet.training.SupervisedTrainer(model, optimizer='adam', loss='cross_entropy', metrics=None)

A simple trainer that implements a supervised approach: a model \(\hat{y} = f(x)\) is trained so that its prediction \(\hat{y}\) matches the ground truth \(y\) according to some specified loss.

This is the training routine that most high-level deep learning frameworks implement.

Parameters:
  • model (nn.Module) – The model that needs to be trained
  • optimizer (str or optim.Optimizer) – The optimizer used to train the model. Default: 'adam'
  • loss (str or callable) – A loss function that gives the objective to be minimized. Default: 'cross_entropy'
  • metrics (list) – Any other metrics that need to be monitored. Default: None
  • optimizer can be an actual optim.Optimizer instance or the name of a popular optimizer (e.g. 'adam').
  • loss can be a function or the name of a popular loss function (e.g. 'cross_entropy'). It should accept 2 arguments (\(\hat{y}\), \(y\)).
  • metrics should contain a list of functions which accept 2 arguments (\(\hat{y}\), \(y\)), like the loss function.
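
Because loss may be given either by name or as a callable, and metrics share its \((\hat{y}, y)\) signature, names have to be resolved to functions. A minimal sketch with plain lists standing in for tensors; the `LOSSES` table, `resolve_loss` and `accuracy` are hypothetical, and unlike PyTorch's `torch.nn.functional.cross_entropy` (which expects raw logits) this toy loss takes probabilities:

```python
import math


def cross_entropy(y_hat, y):
    """Mean negative log-probability of the true class; y_hat holds probabilities."""
    return -sum(math.log(probs[label]) for probs, label in zip(y_hat, y)) / len(y)


def accuracy(y_hat, y):
    """A metric with the same (y_hat, y) signature as the loss."""
    hits = sum(max(range(len(p)), key=p.__getitem__) == label for p, label in zip(y_hat, y))
    return hits / len(y)


LOSSES = {'cross_entropy': cross_entropy}      # hypothetical name -> function table


def resolve_loss(loss):
    """Accept either a callable or the name of a known loss."""
    return loss if callable(loss) else LOSSES[loss]


y_hat = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]   # predicted class probabilities
y = [0, 1, 1]                                   # ground-truth labels

loss_fn = resolve_loss('cross_entropy')
metrics = [accuracy]
print(loss_fn(y_hat, y), [metric(y_hat, y) for metric in metrics])
```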

Note

A static validate() function is provided for use with the Validate callback.

Note

The metrics are of no use unless some callback (e.g. callbacks.Monitor) receives them.

Examples:

>>> import magnet as mag
>>> import magnet.nodes as mn

>>> from magnet.data import Data
>>> from magnet.training import callbacks, SupervisedTrainer

>>> data = Data.get('mnist')

>>> model = mn.Linear(10, act=None)
>>> model.build(x=next(data())[0])

>>> trainer = SupervisedTrainer(model)
>>> callbacks = [callbacks.Monitor(),
...              callbacks.Validate(data(64, mode='val'), SupervisedTrainer.validate)]
>>> trainer.train(data(64, shuffle=True), 1, callbacks)

magnet.training.finish_training(path, names=None)

A helper function for cleaning up the training logs and other checkpoints and retaining only the state_dicts of the trained models.

Parameters:
  • path (pathlib.Path) – The path where the trainer was checkpointed
  • names (list) – The names of the models in the order given to the trainer. Default: None
  • names can be used if the models themselves did not have names prior to training. The checkpoints default to an ordered naming scheme. If passed, the files are additionally renamed to these names.

Note

Does nothing / fails silently if the path does not exist.

Example:

>>> # Assume that we've defined two models - encoder and decoder,
>>> # and a suitable trainer. The models do not have a 'name' attribute.

>>> trainer.save_state(checkpoint_path / 'my-trainer')

>>> # Suppose the checkpoint directory contains the following files:
>>> # my-trainer/
>>> #     models/
>>> #         0.pt
>>> #         1.pt
>>> #     callbacks/
>>> #         monitor/
>>> #         babysitter/
>>> #     state.p

>>> finish_training(checkpoint_path / 'my-trainer', names=['encoder', 'decoder'])

>>> # Now the directory contains these files:
>>> # encoder.pt
>>> # decoder.pt
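
The clean-up can be approximated with `pathlib` and `shutil`. This is a hypothetical re-implementation (`finish_training_sketch`), not MagNet's code. It assumes the layout in the example, that unnamed models are stored as `models/0.pt`, `models/1.pt`, …, and that the surviving files end up directly inside the checkpoint directory:

```python
import shutil
import tempfile
from pathlib import Path


def finish_training_sketch(path, names=None):
    """Hypothetical stand-in for finish_training(); assumes the layout shown above."""
    path = Path(path)
    if not path.exists():                       # documented behaviour: silently do nothing
        return

    # Assumption: unnamed models are saved as models/0.pt, models/1.pt, ...
    model_files = sorted((path / 'models').glob('*.pt'), key=lambda f: int(f.stem))
    if names is not None:
        targets = [f'{name}.pt' for name in names]
    else:
        targets = [f.name for f in model_files]

    for model_file, target in zip(model_files, targets):
        model_file.replace(path / target)       # move (and optionally rename) each state_dict

    # Delete everything else: models/, callbacks/, state.p, ...
    for entry in list(path.iterdir()):
        if entry.name in targets:
            continue
        if entry.is_dir():
            shutil.rmtree(entry)
        else:
            entry.unlink()


with tempfile.TemporaryDirectory() as tmp:
    checkpoint = Path(tmp) / 'my-trainer'
    for sub in ('models', 'callbacks/monitor', 'callbacks/babysitter'):
        (checkpoint / sub).mkdir(parents=True)
    for name in ('models/0.pt', 'models/1.pt', 'state.p'):
        (checkpoint / name).write_bytes(b'')

    finish_training_sketch(checkpoint, names=['encoder', 'decoder'])
    remaining = sorted(p.name for p in checkpoint.iterdir())

print(remaining)   # ['decoder.pt', 'encoder.pt']
```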

magnet.training.callbacks

CallbackQueue

class magnet.training.callbacks.CallbackQueue

A container for multiple callbacks that can be called in parallel.

If multiple callbacks need to be called together (as intended), they can be registered via this class.

Since callbacks need to be unique (by their name), this class also ensures that there are no duplicates.

__call__(signal, *args, **kwargs)

Broadcasts a signal to all registered callbacks along with payload arguments.

Parameters:
  • signal (object) – Any object that is broadcast as a signal.

Note

Any other arguments will be sent as-is to the callbacks.

find(name)

Scans through the registered list and finds the callback with the given name.

If not found, returns None.

Raises: RuntimeError – If multiple callbacks are found.
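
The queue semantics described above (unique names, broadcasting, lookup by name) can be sketched as a `list` subclass. `CallbackQueueSketch` and `RecordingCallback` are hypothetical; this sketch assumes that callbacks expose a `name` attribute and take `signal` as a keyword argument, matching the `__call__(trainer, signal, **kwargs)` signatures documented for the callbacks in this module.

```python
class RecordingCallback:
    """Hypothetical callback that records every signal it receives."""

    def __init__(self, name):
        self.name = name
        self.received = []

    def __call__(self, trainer, signal, **kwargs):
        self.received.append(signal)


class CallbackQueueSketch(list):
    """Hypothetical list-based queue: unique names, broadcasting and lookup by name."""

    def __init__(self, callbacks=()):
        super().__init__()
        for callback in callbacks:
            self.append(callback)               # route through the duplicate check

    def append(self, callback):
        if self.find(callback.name) is not None:
            raise RuntimeError(f"A callback named '{callback.name}' is already registered")
        super().append(callback)

    def __call__(self, signal, *args, **kwargs):
        # Send the same signal and payload to every callback, one after another.
        for callback in self:
            callback(*args, signal=signal, **kwargs)

    def find(self, name):
        matches = [callback for callback in self if callback.name == name]
        if len(matches) > 1:
            raise RuntimeError(f"Multiple callbacks named '{name}' found")
        return matches[0] if matches else None


monitor = RecordingCallback('monitor')
queue = CallbackQueueSketch([monitor, RecordingCallback('validate')])
queue('on_batch_end', trainer=None)

print(monitor.received)            # ['on_batch_end']
print(queue.find('checkpoint'))    # None
try:
    queue.append(RecordingCallback('monitor'))
except RuntimeError as error:
    print(error)                   # A callback named 'monitor' is already registered
```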

Monitor

class magnet.training.callbacks.Monitor(frequency=10, show_progress=True, **kwargs)

Allows easy monitoring of the training process.

Stores any metric / quantity broadcast using the 'write_stats' signal.

Also adds a nice progress bar!

Parameters:
  • frequency (int) – The number of times per epoch to flush the buffer. Default: 10
  • show_progress (bool) – If True, adds a progress bar. Default: True
Keyword Arguments:
  • name (str) – Name of this callback. Default: 'monitor'

  • frequency is useful only if there are buffered metrics.

Examples:

>>> import torch

>>> import magnet as mag
>>> import magnet.nodes as mn

>>> from magnet.training import callbacks, SupervisedTrainer

>>> model = mn.Linear(10, act=None)
>>> with mag.eval(model): model(torch.randn(4, 1, 28, 28))

>>> trainer = SupervisedTrainer(model)

>>> callbacks = callbacks.CallbackQueue([callbacks.Monitor()])
>>> callbacks(signal='write_stats', trainer=trainer, key='loss', value=0.1)

>>> callbacks[0].history
{'loss': [{'val': 0.1}]}

__call__(trainer, signal, **kwargs)

Responds to the following signals:

  • 'write_stats': Any keyword arguments will be passed to the History.append() method.
  • 'on_training_start': To be called before the start of training. Initializes the progress bar.
  • 'on_batch_start': Called before the training loop. Updates the progress bar.
  • 'on_batch_end': Called after the training loop. Flushes the history buffer if needed and sets the progress bar description.
  • 'on_training_end': To be called after training. Closes the progress bar.
  • 'load_state': Loads the state of this callback from path.
  • 'save_state': Saves the state of this callback to path.

show(metric=None, log=False, x_key='epochs', **kwargs)

Calls the corresponding History.show() method.

Validate

class magnet.training.callbacks.Validate(dataloader, validate, frequency=10, batches=None, drop_last=False, **kwargs)

Runs a validation function over a dataset during the course of training.

Most machine learning research uses a held-out validation set as a proxy for the test set / real-life data. Hyperparameters are usually tuned on the validation set.

Often, this is done during training in order to track performance on the validation set alongside training and to catch any overfitting / underfitting.

This callback enables you to run a custom validate function over a dataloader.

Parameters:
  • dataloader (DataLoader) – DataLoader containing the validation set
  • validate (callable) – A callable that does the validation
  • frequency (int) – The number of times per epoch to run the function. Default: \(10\)
  • batches (int or None) – The number of times / batches to call the validate function in each run. Default: None
  • drop_last (bool) – If True, the last batch is not run. Default: False
Keyword Arguments:
  • name (str) – Name of this callback. Default: 'validate'

  • validate is a function which takes two arguments: (trainer, dataloader).

  • batches defaults to a value which ensures that an epoch of the validation set matches an epoch of the training set.

    For instance, if the training set has \(80\) datapoints and the validation set has \(20\) and the batch size is \(1\) for both, an epoch consists of \(80\) iterations for the training set and \(20\) for the validation set.

    If the validate function is run \(10\) times (frequency) per epoch of the training set, then batches must be \(20 / 10 = 2\).
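
The arithmetic in the example generalizes as follows. A minimal sketch, assuming integer division with a floor of one (MagNet's exact rounding is not stated here); `default_validation_batches` and `validation_interval` are hypothetical helpers:

```python
def default_validation_batches(val_batches_per_epoch, frequency=10):
    """Batches per validation run so that `frequency` runs cover one validation epoch."""
    return max(1, val_batches_per_epoch // frequency)


def validation_interval(train_batches_per_epoch, frequency=10):
    """Training iterations between two consecutive validation runs."""
    return max(1, train_batches_per_epoch // frequency)


# The example above: 80 training and 20 validation datapoints, batch size 1.
print(validation_interval(80), default_validation_batches(20))   # 8 2
```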

__call__(trainer, signal, **kwargs)

Responds to the following signals:

  • 'on_training_start': To be called before the start of training. Automatically finds the number of batches per run.
  • 'on_batch_end': Called after the training loop. Calls the validate function.
  • 'on_training_end': To be called after training. If drop_last, calls the validate function.
  • 'load_state': Loads the state of this callback from path.
  • 'save_state': Saves the state of this callback to path.

Checkpoint

class magnet.training.callbacks.Checkpoint(path, interval='5 m', **kwargs)

Serializes stateful objects during the training process.

For many practical Deep Learning projects, training takes many hours, even days.

As such, it is only natural that you’d want to save the progress every once in a while.

This callback saves the models, optimizers, schedulers and the trainer itself periodically and automatically loads from those states if found.

Parameters:
  • path (pathlib.Path) – The root path to save to
  • interval (str) – The time between checkpoints. Default: ‘5 m’
Keyword Arguments:
  • name (str) – Name of this callback. Default: 'checkpoint'

  • interval should be a string of the form '{duration} {unit}'. Valid units are: 'us' (microseconds), 'ms' (milliseconds), 's' (seconds), 'm' (minutes), 'h' (hours), 'd' (days).
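
The interval format maps naturally onto a unit table. `parse_interval` below is a hypothetical helper, not part of MagNet, that converts such a string into seconds:

```python
UNIT_SECONDS = {
    'us': 1e-6,    # microseconds
    'ms': 1e-3,    # milliseconds
    's': 1.0,      # seconds
    'm': 60.0,     # minutes
    'h': 3600.0,   # hours
    'd': 86400.0,  # days
}


def parse_interval(interval):
    """Convert a '{duration} {unit}' string such as '5 m' into seconds."""
    duration, unit = interval.split()
    if unit not in UNIT_SECONDS:
        raise ValueError(f'unknown unit {unit!r}; expected one of {sorted(UNIT_SECONDS)}')
    return float(duration) * UNIT_SECONDS[unit]


print(parse_interval('5 m'))     # 300.0
print(parse_interval('1.5 h'))   # 5400.0
```

A Checkpoint-like callback could then compare the time elapsed since the last checkpoint (for example via `time.monotonic()`) against this value on each 'on_batch_end'.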

__call__(trainer, signal, **kwargs)

Responds to the following signals:

  • 'on_training_start': To be called before the start of training. Creates the path if it doesn’t exist and loads from it if it does. Also sets the starting time.
  • 'on_batch_end': Called after the training loop. Checkpoints if the interval is crossed and resets the clock.
  • 'on_training_end': To be called after training. Checkpoints one last time.
  • 'load_state': Loads the state of this callback from path.
  • 'save_state': Saves the state of this callback to path.

ColdStart

class magnet.training.callbacks.ColdStart(epochs=0.1, **kwargs)

Starts the trainer in eval mode for a few iterations.

Sometimes, you may want to find out how the model performs prior to any training. This callback freezes the training initially.

Parameters:
  • epochs (float) – The number of epochs to freeze the trainer. Default: \(0.1\)
Keyword Arguments:
  • name (str) – Name of this callback. Default: 'coldstart'

__call__(trainer, signal, **kwargs)

Responds to the following signals:

  • 'on_training_start': To be called before the start of training. Sets the models in eval mode.
  • 'on_batch_end': Called after the training loop. Once the specified number of epochs has elapsed, unfreezes the trainer and removes this callback from the queue.
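
The freeze-then-unfreeze cycle can be simulated without PyTorch. `ColdStartSketch` and `ToyTrainer` are hypothetical; a boolean `training` flag stands in for putting the models in eval mode, and the number of frozen iterations is assumed to be `int(epochs * batches_per_epoch)`:

```python
class ColdStartSketch:
    """Hypothetical callback that keeps the trainer frozen for the first `epochs`."""

    def __init__(self, epochs=0.1):
        self.epochs = epochs
        self.frozen_iterations = 0

    def __call__(self, trainer, signal, **kwargs):
        if signal == 'on_training_start':
            self.frozen_iterations = int(self.epochs * trainer.batches_per_epoch)
            trainer.training = False            # stands in for putting the models in eval mode
        elif signal == 'on_batch_end' and trainer.iterations >= self.frozen_iterations:
            trainer.training = True             # unfreeze the trainer
            trainer.callbacks.remove(self)      # and leave the queue


class ToyTrainer:
    """Hypothetical trainer loop; only the parts ColdStartSketch touches."""

    def __init__(self, batches_per_epoch, callbacks):
        self.batches_per_epoch = batches_per_epoch
        self.callbacks = list(callbacks)
        self.iterations = 0
        self.training = True
        self.updates = 0                        # what would be optimizer steps

    def broadcast(self, signal):
        for callback in list(self.callbacks):   # copy: callbacks may remove themselves
            callback(self, signal)

    def train(self, iterations):
        self.broadcast('on_training_start')
        for _ in range(iterations):
            if self.training:                   # protocol: no updates while frozen
                self.updates += 1
            self.iterations += 1
            self.broadcast('on_batch_end')


trainer = ToyTrainer(batches_per_epoch=100, callbacks=[ColdStartSketch(epochs=0.1)])
trainer.train(iterations=100)
print(trainer.updates, trainer.callbacks)   # 90 []
```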

LRScheduler

class magnet.training.callbacks.LRScheduler(scheduler, **kwargs)

A helper callback for adding learning-rate schedulers.

Parameters:
  • scheduler (LRScheduler) – The scheduler.
Keyword Arguments:
  • name (str) – Name of this callback. Default: 'lr_scheduler'

__call__(trainer, signal, **kwargs)

Responds to the following signals:

  • 'on_batch_start': Called before the training loop. If it is the start of an epoch, steps the scheduler.
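
The stepping rule can be sketched with a toy scheduler that halves the learning rate on each `step()`, similar in spirit to PyTorch's `torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)`. `HalvingScheduler` and `LRSchedulerSketch` are hypothetical, the callback signature is simplified, and, read literally, "the start of an epoch" includes the very first batch:

```python
class HalvingScheduler:
    """Toy scheduler: halves the learning rate each time step() is called."""

    def __init__(self, lr):
        self.lr = lr

    def step(self):
        self.lr /= 2


class LRSchedulerSketch:
    """Hypothetical callback (simplified signature): steps at each epoch's first batch."""

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def __call__(self, iteration, batches_per_epoch, signal):
        if signal == 'on_batch_start' and iteration % batches_per_epoch == 0:
            self.scheduler.step()


scheduler = HalvingScheduler(lr=0.1)
callback = LRSchedulerSketch(scheduler)
for iteration in range(3 * 4):                 # 3 epochs of 4 batches each
    callback(iteration, batches_per_epoch=4, signal='on_batch_start')
print(scheduler.lr)                            # stepped 3 times: 0.1 / 8 = 0.0125
```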

magnet.training.history

class magnet.training.history.History

A dictionary-like repository which is used to store several metrics of interest in training in the form of snapshots.

This object can be utilized to collect, store and analyze training metrics against a variety of features of interest (epochs, iterations, time etc.)

Since this is a subclass of dict, it can be used as such. However, it is preferable to operate on it using the class-specific methods.

Examples:

>>> from time import time
>>> from magnet.training.history import History

>>> history = History()

>>> # Add a simple value with a time stamp.
>>> # This is like the statement: history['loss'] = 69
>>> # However, any additional stamps can also be attached.
>>> history.append('loss', 69, time=time())
>>> history
{'loss': [{'val': 69, 'time': 1535095251.6717412}]}

>>> history.clear()

>>> # Use a small buffer size of 10.
>>> # This means that only the latest 10 values are kept.
>>> for i in range(100): history.append('loss', i, buffer_size=10)

>>> # Flush the buffer with a time stamp.
>>> history.flush(time=time())

>>> # The mean of the last 10 values is now stored.
>>> history
{'loss': [{'val': 94.5, 'time': 1535095320.9745226}]}
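
The append / buffer / flush cycle above can be sketched with a plain `dict` subclass. `HistorySketch` is hypothetical and omits plotting and SnapShot objects; it assumes that a buffered key keeps only the latest buffer_size values, that flush() stores their mean as a new snapshot, and that validation=True prefixes the key with 'val_'. A fixed `iterations` stamp replaces the time stamp so the output is reproducible.

```python
from collections import deque
from statistics import mean


class HistorySketch(dict):
    """Hypothetical dict of key -> list of snapshots, with optional per-key buffers."""

    def __init__(self):
        super().__init__()
        self.buffers = {}

    def append(self, key, value, validation=False, buffer_size=None, **stamps):
        if validation:
            key = 'val_' + key
        if buffer_size is None:
            # No buffer: store the value immediately along with its stamps.
            self.setdefault(key, []).append({'val': value, **stamps})
        else:
            # Buffered: keep only the latest buffer_size values until the next flush.
            self.setdefault(key, [])
            self.buffers.setdefault(key, deque(maxlen=buffer_size)).append(value)

    def flush(self, key=None, **stamps):
        keys = [key] if key is not None else list(self.buffers)
        for k in keys:
            buffer = self.buffers.get(k)
            if buffer:
                self[k].append({'val': mean(buffer), **stamps})
                buffer.clear()


history = HistorySketch()
for i in range(100):
    history.append('loss', i, buffer_size=10)
history.flush(iterations=100)
print(history)   # {'loss': [{'val': 94.5, 'iterations': 100}]}
```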

find(key)

A helper method that returns a dictionary filtered by a search key.

Parameters:
  • key (str) – The filter key

Examples:

>>> # Assume the history has empty records for the keys: ['loss', 'val_loss',
>>> # 'encoder_loss', 'accuracy', 'weird-metric']

>>> history.find('loss')
{'loss': [], 'val_loss': [], 'encoder_loss': []}
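
Judging from the example, find() keeps every key that contains the search string. A sketch over a plain `dict`, where `find_keys` is a hypothetical helper and substring matching is an assumption:

```python
def find_keys(history, key):
    """Return the entries whose key contains `key` as a substring."""
    return {k: v for k, v in history.items() if key in k}


history = {name: [] for name in ['loss', 'val_loss', 'encoder_loss', 'accuracy', 'weird-metric']}
print(find_keys(history, 'loss'))   # {'loss': [], 'val_loss': [], 'encoder_loss': []}
```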

append(key, value, validation=False, buffer_size=None, **stamps)

Append a new snapshot to the history.

Parameters:
  • key (str) – The dictionary key / name of the object
  • value (object) – The actual object
  • validation (bool) – Whether this is a validation metric. Default: False
  • buffer_size (int or None) – The size of the buffer of the key. Default: None
  • validation is just a convenient key modifier. It prepends 'val_' to the key.

  • buffer_size defines the size of the storage buffer for the specific key.

    The latest buffer_size snapshots are stored.

    If None, the key is stored as is.

Note

Any further keyword arguments define stamps that are essentially the signatures for the snapshot.

show(key=None, log=False, x_key=None, validation=True, legend=None, **kwargs)

Plot the snapshots for a key against a stamp.

Parameters:
  • key (str) – The key of the record
  • log (bool) – If True, the y-axis will be log-scaled. Default: False
  • x_key (str or None) – The stamp to use as the x-axis. Default: None
  • validation (bool) – Whether to plot the validation records (if they exist) as well. Default: True
  • legend (str or None) – The legend entry. Default: None
Keyword Arguments:
  • ax (pyplot axes object) – The axis to plot into. Default: None
  • smoothen (bool) – If True, smooths the plot. Default: True
  • window_fraction (float) – How much of the plot to use as a window for smoothing. Default: \(0.3\)
  • gain (float) – How much more dense to make the plot. Default: \(10\)
  • replace_outliers (bool) – If True, replaces outlier datapoints by a sensible value. Default: True
  • key can be None, in which case this method is successively called for all existing keys. The log argument is overridden, however: it is only set to True for keys containing 'loss'.
  • legend can be None, in which case the default legends 'training' and 'validation' are applied respectively.

flush(key=None, **stamps)

Flush the buffer (if it exists) and append the mean.

Parameters:
  • key (str or None) – The key to flush. Default: None
  • key can be None, in which case this method is successively called for all existing keys.

Note

Any further keyword arguments define stamps that are essentially the signatures for the snapshot.

class magnet.training.history.SnapShot(buffer_size=-1)

A list of stamped values (snapshots).

This is used by the History object to store a repository of training metrics.

Parameters:
  • buffer_size (int) – The size of the buffer. Default: \(-1\)
  • If buffer_size is negative, then the snapshots are stored as is.

append(value, buffer=False, **stamps)

Add a new snapshot.

Parameters:
  • value (object) – The value to add
  • buffer (bool) – If True, adds to the buffer instead. Default: False

Note

Any further keyword arguments define stamps that are essentially the signatures for the snapshot.

flush(**stamps)

Flush the buffer (if it exists) and append the mean.

Note

Any keyword arguments define stamps that are essentially the signatures for the snapshot.

show(ax, x=None, label=None, **kwargs)

Plot the snapshots against a stamp.

Parameters:
  • ax (pyplot axes object) – The axis to plot into
  • x (str or None) – The stamp to use as the x-axis. Default: None
  • label (str or None) – The label for the line. Default: None
Keyword Arguments:
  • **kwargs – See History.show() for more details.

Note

Any further keyword arguments are passed to the plot function.

magnet.training.utils

magnet.training.utils.load_object(path, **kwargs)

A convenience method to unpickle a file.

Parameters:
  • path (pathlib.Path) – The path to the pickle file
Keyword Arguments:
  • default (object) – A default value to be returned if the file does not exist. Default: None
Raises: RuntimeError – If a default keyword argument is not provided and the file is not found.

magnet.training.utils.load_state(module, path, alternative_name=None)

Loads the state_dict of a PyTorch object from a specified path.

This is a more robust version of the PyTorch way in the sense that the device mapping is handled automatically.

Parameters:
  • module (object) – Any PyTorch object that has a state_dict
  • path (pathlib.Path) – The path to folder containing the state_dict file
  • alternative_name (str or None) – A fallback name for the file if the module object does not have a name attribute. Default: None
Raises:

RuntimeError – If no alternative_name is provided and the module does not have a name.

Note

If you already know the file name, set alternative_name to that.

This is just a convenience method that assumes that the file name will be the same as the name of the module (if there is one).

magnet.training.utils.save_object(obj, path)

A convenience method to pickle an object.

Parameters:
  • obj (object) – The object to pickle
  • path (pathlib.Path) – The path to the pickle file

Note

If the path does not exist, it is created.
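
A pickle round trip with both documented behaviours (creating a missing path on save, falling back to default on load) can be sketched with the standard library. `save_object_sketch` and `load_object_sketch` are hypothetical; the sketch assumes that "the path is created" means its parent directories, and uses a sentinel so that an explicit default=None still counts as provided:

```python
import pickle
import tempfile
from pathlib import Path

_MISSING = object()   # sentinel: distinguishes "no default given" from default=None


def save_object_sketch(obj, path):
    """Pickle obj to path, creating missing parent directories first."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open('wb') as f:
        pickle.dump(obj, f)


def load_object_sketch(path, default=_MISSING):
    """Unpickle path; return default if one was given and the file is missing."""
    path = Path(path)
    if not path.exists():
        if default is _MISSING:
            raise RuntimeError(f'{path} does not exist and no default was given')
        return default
    with path.open('rb') as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / 'nested' / 'state.p'      # 'nested/' does not exist yet
    save_object_sketch({'iterations': 42}, target)
    restored = load_object_sketch(target)
    fallback = load_object_sketch(Path(tmp) / 'missing.p', default={})

print(restored, fallback)   # {'iterations': 42} {}
```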

magnet.training.utils.save_state(module, path, alternative_name=None)

Saves the state_dict of a PyTorch object to a specified path.

Parameters:
  • module (object) – Any PyTorch object that has a state_dict
  • path (pathlib.Path) – The path to a folder to save the state_dict to
  • alternative_name (str or None) – A fallback name for the file if the module object does not have a name attribute. Default: None
Raises:

RuntimeError – If no alternative_name is provided and the module does not have a name.
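
Both load_state() and save_state() first have to decide which file to use. A sketch of that lookup, assuming the file is `<name>.pt` inside path (the `.pt` extension is taken from the finish_training() example); `state_file` is a hypothetical helper, and the actual `torch.save` / `torch.load(..., map_location=...)` calls are left out so the snippet runs without PyTorch:

```python
from pathlib import Path


def state_file(module, path, alternative_name=None):
    """Resolve the file holding module's state_dict (hypothetical helper)."""
    name = getattr(module, 'name', None) or alternative_name
    if name is None:
        raise RuntimeError('The module has no name attribute and no alternative_name was given')
    return Path(path) / f'{name}.pt'


class Encoder:
    name = 'encoder'


class Unnamed:
    pass


print(state_file(Encoder(), 'checkpoints'))                         # checkpoints/encoder.pt
print(state_file(Unnamed(), 'checkpoints', alternative_name='0'))   # checkpoints/0.pt
```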