magnet.training¶

Trainer¶

class magnet.training.Trainer(models, optimizers)[source]¶

Abstract base class for training models.
The Trainer class makes it simple and convenient to train, monitor, debug and checkpoint entire Deep Learning projects.
Simply define your training loop by implementing the optimize() method.

Parameters:
- models (list of nn.Module) – All the models that need to be trained
- optimizers (list of optim.Optimizer) – Any optimizers that are used
Note
If any model is in eval() mode, the trainer is switched off. This means that, as per protocol, none of the models will train.
Variables: callbacks (list) – A list of callbacks attached to the trainer.

Take a look at SupervisedTrainer for an idea on how to extend this class.

optimize()[source]¶

Defines the core optimization loop. This method is called on each iteration.
Two quick protocols that one needs to follow are:

1. Do NOT actually backpropagate or step() the optimizers if the trainer is not training. Use the is_training() method to find out. This is essential since it ensures that the trainer behaves as expected when is_training() is False. Useful, for example, in cases like callbacks.ColdStart.

2. Send a callback the signal 'gradient' with a keyword argument 'models' that is the list of models that accumulate a gradient. Usually, it's all the modules (self.modules). Any callbacks that listen to this signal are interested in the gradient information (eg. callbacks.Babysitter).
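As an illustrative sketch (not the library's actual implementation), an optimize() override for a single model and optimizer might follow both protocols like this; the self.dataloader attribute and the cross-entropy loss are assumptions for the example:

>>> import torch.nn.functional as F
>>> from magnet.training import Trainer
>>> class MyTrainer(Trainer):
...     def optimize(self):
...         x, y = next(self.dataloader)          # assumed data source
...         loss = F.cross_entropy(self.models[0](x), y)
...         if self.is_training():                # Protocol 1
...             loss.backward()
...             # Protocol 2: broadcast the 'gradient' signal with the
...             # models that have accumulated a gradient.
...             self.callbacks('gradient', trainer=self, models=self.models)
...             self.optimizers[0].step()
...             self.optimizers[0].zero_grad()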
train(dataloader, epochs=1, callbacks=None, **kwargs)[source]¶

Starts the training process.

Parameters:
- dataloader – The iterable that yields the training data (see the Note below)
- epochs (int) – The number of epochs to train for. Default: 1
- callbacks (list) – A list of callbacks to attach to the trainer. Default: None

Keyword Arguments: iterations (int) – The number of iterations to train for. Overrides epochs.

Note
PyTorch DataLoaders are not supported. Ideally, encapsulate your dataset in the Data class.
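For example, to cap training by iteration count (a sketch assuming the data and trainer from the SupervisedTrainer example below):

>>> # Train for exactly 500 iterations, regardless of the epochs argument.
>>> trainer.train(data(64, shuffle=True), epochs=10, iterations=500)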
mock(path=None)[source]¶

A context manager that creates a temporary 'safe' scope for training.

All changes to stateful objects (models, optimizers and the trainer itself) are forgotten once out of this scope.

This is very useful if you need to try out what-if experiments.

Parameters: path (pathlib.Path) – The path to save temporary states into. Default: {System temp directory}/.mock_trainer
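For instance, a what-if experiment can be rolled back cleanly (a sketch assuming the trainer and data from the SupervisedTrainer example below):

>>> with trainer.mock():
...     trainer.train(data(64, shuffle=True))  # try out one epoch
>>> # Back outside the scope, the models, optimizers and trainer
>>> # state are exactly as they were before the mock run.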
SupervisedTrainer¶

class magnet.training.SupervisedTrainer(model, optimizer='adam', loss='cross_entropy', metrics=None)[source]¶

A simple trainer that implements a supervised approach where a simple model \(\hat{y} = f(x)\) is trained to match \(\hat{y}\) to the ground-truth \(y\) according to some specified loss.
This is the training routine that most high-level deep learning frameworks implement.
Parameters:
- model (nn.Module) – The model that needs to be trained
- optimizer (str or optim.Optimizer) – The optimizer used to train the model. Default: 'adam'
- loss (str or callable) – A loss function that gives the objective to be minimized. Default: 'cross_entropy'
- metrics (list) – Any other metrics that need to be monitored. Default: None
optimizer can be an actual optim.Optimizer instance or the name of a popular optimizer (eg. 'adam').

loss can be a function or the name of a popular loss function (eg. 'cross_entropy'). It should accept 2 arguments (\(\hat{y}\), \(y\)).

metrics should contain a list of functions which accept 2 arguments (\(\hat{y}\), \(y\)), like the loss function.
Note
A static validate() function is provided for the validation callback.

Note
The metrics are of no use unless there is some callback (eg. callbacks.Monitor) to receive them.

Examples:
>>> import magnet as mag
>>> import magnet.nodes as mn
>>> from magnet.data import Data
>>> from magnet.training import callbacks, SupervisedTrainer
>>> data = Data.get('mnist')
>>> model = mn.Linear(10, act=None)
>>> model.build(x=next(data())[0])
>>> trainer = SupervisedTrainer(model)
>>> callbacks = [callbacks.Monitor(),
...              callbacks.Validate(data(64, mode='val'), SupervisedTrainer.validate)]
>>> trainer.train(data(64, shuffle=True), 1, callbacks)
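Custom callables work too; a minimal sketch (the accuracy function here is hypothetical, not part of the library):

>>> import torch.nn.functional as F
>>> def accuracy(y_hat, y):
...     return (y_hat.argmax(dim=-1) == y).float().mean()
>>> trainer = SupervisedTrainer(model, loss=F.cross_entropy,
...                             metrics=[accuracy])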
magnet.training.finish_training(path, names=None)[source]¶

A helper function for cleaning up the training logs and other checkpoints and retaining only the state_dicts of the trained models.
Parameters:
- path (pathlib.Path) – The path where the trainer was checkpointed
- names (list) – The names of the models in the order given to the trainer. Default: None

names can be used if the models themselves did not have names prior to training. The checkpoints default to an ordered naming scheme. If passed, the files are additionally renamed to these names.
Note
Does nothing / fails silently if the path does not exist.
Example:
>>> # Assume that we've defined two models - encoder and decoder,
>>> # and a suitable trainer. The models do not have a 'name' attribute.
>>> trainer.save_state(checkpoint_path / 'my-trainer')
>>> # Suppose the checkpoint directory contains the following files:
>>> # my-trainer/
>>> #     models/
>>> #         0.pt
>>> #         1.pt
>>> #     callbacks/
>>> #         monitor/
>>> #         babysitter/
>>> #     state.p
>>> finish_training(path, names=['encoder', 'decoder'])
>>> # Now the directory contains these files:
>>> # encoder.pt
>>> # decoder.pt
magnet.training.callbacks¶

CallbackQueue¶

class magnet.training.callbacks.CallbackQueue[source]¶

A container for multiple callbacks that can be called in parallel.

If multiple callbacks need to be called together (as intended), they can be registered via this class.

Since callbacks need to be unique (by their name), this class also ensures that there are no duplicates.
__call__(signal, *args, **kwargs)[source]¶

Broadcasts a signal to all registered callbacks along with payload arguments.

Parameters: signal (object) – Any object that is broadcast as a signal.

Note
Any other arguments will be sent as-is to the callbacks.
find(name)[source]¶

Scans through the registered list and finds the callback with name. If not found, returns None.

Raises: RuntimeError – If multiple callbacks are found.
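A quick sketch of registering callbacks and retrieving one by name:

>>> from magnet.training.callbacks import CallbackQueue, Monitor
>>> queue = CallbackQueue([Monitor()])
>>> monitor = queue.find('monitor')   # the Monitor instance
>>> queue.find('does-not-exist') is None
True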
Monitor¶

class magnet.training.callbacks.Monitor(frequency=10, show_progress=True, **kwargs)[source]¶

Allows easy monitoring of the training process.
Stores any metric / quantity broadcast using the 'write_stats' signal. Also adds a nice progress bar!

Parameters:
- frequency (int) – The number of times per epoch to flush buffered metrics. Default: \(10\)
- show_progress (bool) – Whether to show the progress bar. Default: True

Keyword Arguments: name (str) – Name of this callback. Default: 'monitor'

frequency is useful only if there are buffered metrics.
Examples:
>>> import torch
>>> import magnet as mag
>>> import magnet.nodes as mn
>>> from magnet.training import callbacks, SupervisedTrainer
>>> model = mn.Linear(10, act=None)
>>> with mag.eval(model): model(torch.randn(4, 1, 28, 28))
>>> trainer = SupervisedTrainer(model)
>>> callbacks = callbacks.CallbackQueue([callbacks.Monitor()])
>>> callbacks(signal='write_stats', trainer=trainer, key='loss', value=0.1)
>>> callbacks[0].history
{'loss': [{'val': 0.1}]}
__call__(trainer, signal, **kwargs)[source]¶

Responds to the following signals:
- 'write_stats': Any keyword arguments will be passed to the History.append() method.
- 'on_training_start': To be called before start of training. Initializes the progress bar.
- 'on_batch_start': Called before the training loop. Updates the progress bar.
- 'on_batch_end': Called after the training loop. Flushes the history buffer if needed and sets the progress bar description.
- 'on_training_end': To be called after training. Closes the progress bar.
- 'load_state': Loads the state of this callback from path.
- 'save_state': Saves the state of this callback to path.
Validate¶

class magnet.training.callbacks.Validate(dataloader, validate, frequency=10, batches=None, drop_last=False, **kwargs)[source]¶

Runs a validation function over a dataset during the course of training.
Most Machine Learning research uses a held-out validation set as a proxy for the test set / real-life data. Hyperparameters are usually tuned on the validation set.

Often, this is done during training in order to monitor learning on the validation set as it happens and catch any overfitting / underfitting.
This callback enables you to run a custom validate function over a dataloader.

Parameters:
- dataloader (DataLoader) – DataLoader containing the validation set
- validate (callable) – A callable that does the validation
- frequency (int) – The number of times per epoch to run the function. Default: \(10\)
- batches (int or None) – The number of times / batches to call the validate function in each run. Default: None
- drop_last (bool) – If True, the last batch is not run. Default: False

Keyword Arguments: name (str) – Name of this callback. Default: 'validate'
validate is a function which takes two arguments: (trainer, dataloader).

batches defaults to a value which ensures that an epoch of the validation set matches an epoch of the training set. For instance, if the training set has \(80\) datapoints and the validation set has \(20\), and the batch size is \(1\) for both, an epoch consists of \(80\) iterations for the training set and \(20\) for the validation set. If the validate function is run \(10\) times (frequency) per epoch of the training set, then batches must be \(2\).
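A minimal sketch, reusing the data, model and trainer from the SupervisedTrainer example:

>>> from magnet.training.callbacks import Validate
>>> val = Validate(data(64, mode='val'), SupervisedTrainer.validate)
>>> trainer.train(data(64, shuffle=True), 1, [val])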
__call__(trainer, signal, **kwargs)[source]¶

Responds to the following signals:
- 'on_training_start': To be called before start of training. Automatically finds the number of batches per run.
- 'on_batch_end': Called after the training loop. Calls the validate function.
- 'on_training_end': To be called after training. If drop_last, calls the validate function.
- 'load_state': Loads the state of this callback from path.
- 'save_state': Saves the state of this callback to path.
Checkpoint¶

class magnet.training.callbacks.Checkpoint(path, interval='5 m', **kwargs)[source]¶

Serializes stateful objects during the training process.
For many practical Deep Learning projects, training takes many hours, even days.
As such, it is only natural that you’d want to save the progress every once in a while.
This callback saves the models, optimizers, schedulers and the trainer itself periodically and automatically loads from those states if found.
Parameters:
- path (pathlib.Path) – The root path to save to
- interval (str) – The time between checkpoints. Default: '5 m'

Keyword Arguments: name (str) – Name of this callback. Default: 'checkpoint'

interval should be a string of the form '{duration} {unit}'. Valid units are: 'us' (microseconds), 'ms' (milliseconds), 's' (seconds), 'm' (minutes), 'h' (hours), 'd' (days).
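A minimal sketch (the path is arbitrary; data and trainer as in the SupervisedTrainer example):

>>> from pathlib import Path
>>> from magnet.training.callbacks import Checkpoint
>>> checkpoint = Checkpoint(Path('.checkpoints/my-trainer'), interval='30 m')
>>> trainer.train(data(64, shuffle=True), 10, [checkpoint])
>>> # If training is interrupted and later restarted with the same
>>> # path, the callback loads the saved states automatically.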
__call__(trainer, signal, **kwargs)[source]¶

Responds to the following signals:
- 'on_training_start': To be called before start of training. Creates the path if it doesn't exist and loads from it if it does. Also sets the starting time.
- 'on_batch_end': Called after the training loop. Checkpoints if the interval is crossed and resets the clock.
- 'on_training_end': To be called after training. Checkpoints one last time.
- 'load_state': Loads the state of this callback from path.
- 'save_state': Saves the state of this callback to path.
ColdStart¶

class magnet.training.callbacks.ColdStart(epochs=0.1, **kwargs)[source]¶

Starts the trainer in eval mode for a few iterations.

Sometimes, you may want to find out how the model performs prior to any training. This callback freezes the training initially.

Parameters: epochs (float) – The number of epochs to freeze the trainer. Default: \(0.1\)

Keyword Arguments: name (str) – Name of this callback. Default: 'coldstart'
__call__(trainer, signal, **kwargs)[source]¶

Responds to the following signals:
- 'on_training_start': To be called before start of training. Sets the models in eval mode.
- 'on_batch_end': Called after the training loop. If the epochs are exhausted, unfreezes the trainer and removes this callback from the queue.
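A sketch that records the untrained baseline for the first tenth of an epoch before training begins (data and trainer as in the SupervisedTrainer example):

>>> from magnet.training.callbacks import Monitor, ColdStart
>>> trainer.train(data(64, shuffle=True), 1,
...               [Monitor(), ColdStart(epochs=0.1)])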
magnet.training.history¶

class magnet.training.history.History[source]¶

A dictionary-like repository which is used to store several metrics of interest in training in the form of snapshots.
This object can be utilized to collect, store and analyze training metrics against a variety of features of interest (epochs, iterations, time etc.)
Since this is a subclass of dict, it can be used as such. However, it is preferred to operate on it using the class-specific methods.

Examples:
>>> history = History()
>>> # Add a simple value with a time stamp.
>>> # This is like the statement: history['loss'] = 69
>>> # However, any additional stamps can also be attached.
>>> history.append('loss', 69, time=time())
{'loss': [{'val': 69, 'time': 1535095251.6717412}]}
>>> history.clear()
>>> # Use a small buffer-size of 10.
>>> # This means that only the latest 10 values are kept.
>>> for i in range(100): history.append('loss', i, buffer_size=10)
>>> # Flush the buffer with a time stamp.
>>> history.flush(time=time())
>>> # The mean of the last 10 values is now stored.
{'loss': [{'val': 94.5, 'time': 1535095320.9745226}]}
find(key)[source]¶

A helper method that returns a filtered dictionary with a search key.

Parameters: key (str) – The filter key

Examples:

>>> # Assume the history is empty with keys: ['loss', 'val_loss',
>>> #   'encoder_loss', 'accuracy', 'wierd-metric']
>>> history.find('loss')
{'loss': [], 'val_loss': [], 'encoder_loss': []}
append(key, value, validation=False, buffer_size=None, **stamps)[source]¶

Append a new snapshot to the history.

Parameters:
- key (str) – The key of the record
- value (object) – The value to store
- validation (bool) – A convenient key-modifier. Default: False
- buffer_size (int or None) – The size of the storage buffer for this key. Default: None

validation is just a convenient key-modifier. It prepends 'val_' to the key.

buffer_size defines the size of the storage buffer for the specific key. The latest buffer_size snapshots are stored. If None, the key is stored as is.

Note
Any further keyword arguments define stamps that are essentially the signatures for the snapshot.
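For example, the validation modifier and stamps combine like this (a sketch; the displayed repr follows the snapshot format shown above):

>>> history = History()
>>> history.append('loss', 0.25, validation=True, epochs=2)
>>> history['val_loss']
[{'val': 0.25, 'epochs': 2}]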
show(key=None, log=False, x_key=None, validation=True, legend=None, **kwargs)[source]¶

Plot the snapshots for a key against a stamp.

Parameters:
- key (str) – The key of the record
- log (bool) – If True, the y-axis will be log-scaled. Default: False
- x_key (str or None) – The stamp to use as the x-axis. Default: None
- validation (bool) – Whether to plot the validation records (if they exist) as well. Default: True
- legend (str or None) – The legend entry. Default: None

Keyword Arguments:
- ax (pyplot axes object) – The axis to plot into. Default: None
- smoothen (bool) – If True, smoothens the plot. Default: True
- window_fraction (float) – How much of the plot to use as a window for smoothing. Default: \(0.3\)
- gain (float) – How much more dense to make the plot. Default: \(10\)
- replace_outliers (bool) – If True, replaces outlier datapoints by a sensible value. Default: True

key can be None, in which case this method is successively called for all existing keys. The log attribute is overridden, however: it is only set to True for any key with 'loss' in it.

legend can be None, in which case the default legends 'training' and 'validation' are applied respectively.
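For instance, assuming 'time' stamps were recorded with each snapshot, a sketch:

>>> history.show('loss', log=True, x_key='time')   # log-scaled loss vs. time
>>> history.show()   # plot every key with the defaults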
flush(key=None, **stamps)[source]¶

Flush the buffer (if it exists) and append the mean.

Parameters: key (str or None) – The key to flush. Default: None

key can be None, in which case this method is successively called for all existing keys.

Note
Any further keyword arguments define stamps that are essentially the signatures for the snapshot.
class magnet.training.history.SnapShot(buffer_size=-1)[source]¶

A list of stamped values (snapshots).

This is used by the History object to store a repository of training metrics.

Parameters: buffer_size (int) – The size of the buffer. Default: \(-1\)

If buffer_size is negative, the snapshots are stored as is.
append(value, buffer=False, **stamps)[source]¶

Add a new snapshot.

Parameters:
- value (object) – The value to store
- buffer (bool) – If True, the value is added to the buffer instead. Default: False

Note
Any further keyword arguments define stamps that are essentially the signatures for the snapshot.
flush(**stamps)[source]¶

Flush the buffer (if it exists) and append the mean.

Note
Any keyword arguments define stamps that are essentially the signatures for the snapshot.
show(ax, x=None, label=None, **kwargs)[source]¶

Plot the snapshots against a stamp.

Parameters:
- ax (pyplot axes object) – The axis to plot into
- x (str or None) – The stamp to use as the x-axis. Default: None
- label (str or None) – The legend label. Default: None

Keyword Arguments: See History.show() for more details.

Note
Any further keyword arguments are passed to the plot function.
magnet.training.utils¶

magnet.training.utils.load_object(path, **kwargs)[source]¶

A convenience method to unpickle a file.

Parameters: path (pathlib.Path) – The path to the pickle file

Keyword Arguments: default (object) – A default value to be returned if the file does not exist. Default: None

Raises: RuntimeError – If a default keyword argument is not provided and the file is not found.
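A quick sketch (the checkpoint_path and file name are placeholders):

>>> from magnet.training.utils import load_object
>>> # Returns {} instead of raising if the file is missing.
>>> state = load_object(checkpoint_path / 'state.p', default={})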
magnet.training.utils.load_state(module, path, alternative_name=None)[source]¶

Loads the state_dict of a PyTorch object from a specified path.

This is a more robust version of the usual PyTorch way, in the sense that the device mapping is automatically handled.

Parameters:
- module (object) – Any PyTorch object that has a state_dict
- path (pathlib.Path) – The path to the folder containing the state_dict file
- alternative_name (str or None) – A fallback name for the file if the module object does not have a name attribute. Default: None

Raises: RuntimeError – If no alternative_name is provided and the module does not have a name.

Note
If you already know the file name, set alternative_name to that. This is just a convenience method that assumes the file name will be the same as the name of the module (if there is one).
magnet.training.utils.save_object(obj, path)[source]¶

A convenience method to pickle an object.

Parameters:
- obj (object) – The object to pickle
- path (pathlib.Path) – The path to save to

Note
If the path does not exist, it is created.
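A sketch (the path is a placeholder):

>>> from magnet.training.utils import save_object
>>> save_object({'iterations': 500}, checkpoint_path / 'state.p')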
magnet.training.utils.save_state(module, path, alternative_name=None)[source]¶

Saves the state_dict of a PyTorch object to a specified path.

Parameters:
- module (object) – Any PyTorch object that has a state_dict
- path (pathlib.Path) – The path to a folder to save the state_dict to
- alternative_name (str or None) – A fallback name for the file if the module object does not have a name attribute. Default: None

Raises: RuntimeError – If no alternative_name is provided and the module does not have a name.