"""magnet._autograd — helpers for running computations in ``eval`` mode without gradients."""

import torch

from contextlib import contextmanager

# Default device for magnet tensors/modules: prefer CUDA when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): purpose not evident from this file — presumably a flag that
# gates lazy module building elsewhere in magnet; confirm against callers.
build_lock = True

def eval(*modules):
    r"""A Context Manager that makes it easy to run computations in ``eval`` mode.

    It sets modules in their ``eval`` mode and ensures that gradients are not
    computed. This is a more wholesome option than :py:meth:`torch.no_grad`
    since many Modules (BatchNorm, Dropout etc.) behave differently while
    training and testing.

    Examples::

        >>> import magnet as mag
        >>> import magnet.nodes as mn
        >>> import torch

        >>> model = mn.Linear(10)
        >>> x = torch.randn(4, 3)

        >>> # Using eval() as context manager
        >>> with mag.eval(model):
        >>>     model(x)

        >>> # Use as decorator
        >>> @mag.eval(model)
        >>> def foo():
        >>>     return model(x)

        >>> foo()

        >>> # The modules can also be given at runtime by specifying no arguments
        >>> @mag.eval
        >>> def foo(model):
        >>>     return model(x)

        >>> foo()
        >>> # The method then takes modules from the arguments
        >>> # to the decorated function.
    """
    from inspect import isfunction

    # Called as a context manager / decorator factory rather than as a bare
    # decorator: zero arguments, multiple arguments, or a first argument that
    # is not a function.  The ``not modules`` guard fixes an IndexError the
    # original raised on ``mag.eval()`` with no arguments — an empty call now
    # yields a plain no-grad context.
    if not modules or not isfunction(modules[0]) or len(modules) > 1:
        return _eval_context_manager(*modules)

    from functools import wraps

    fn = modules[0]  # The decorated function

    @wraps(fn)
    def new_fn(*args, **kwargs):
        from torch.nn import Module

        # Bare-decorator form: pick up the Modules from the decorated
        # function's runtime arguments (positional and keyword).
        arg_list = list(args) + list(kwargs.values())
        modules = [a for a in arg_list if isinstance(a, Module)]
        with _eval_context_manager(*modules):
            return fn(*args, **kwargs)

    return new_fn
@contextmanager def _eval_context_manager(*modules): states = [] modules = [module for module in modules if module.training] for module in modules: states.append(module.training) module.eval() with torch.no_grad(): try: yield finally: for module, state in zip(modules, states): module.train(state)