Source code for magnet.nodes.nodes

# coding=utf-8
import torch
import magnet as mag

from torch import nn

from magnet.utils.misc import caller_locals

class Node(nn.Module):
    r"""Abstract base class that defines MagNet's Node implementation.

    A Node is a *'self-aware Module'*. It can dynamically parametrize itself
    at runtime.

    For instance, a ``Linear`` Node can infer the input features automatically
    when first called; a ``Conv`` Node can infer the dimensionality (1, 2, 3)
    of the input automatically.

    MagNet's Nodes strive to help the developer as much as possible by finding
    the right hyperparameter values automatically. Ideally, the developer
    shouldn't need to define anything except the basic architecture and the
    inputs and outputs.

    The arguments passed to the constructor are stored in a ``_args``
    attribute as a dictionary. This is later modified by the :py:meth:`build`
    method, which gets automatically called on the first forward pass.

    Keyword Args:
        name (str): A printable name for this node. Default: Class Name
    """

    def __init__(self, *args, **kwargs):
        # NOTE(review): _parse_args() collects the constructor arguments via
        # caller_locals(), which presumably inspects this frame's locals.
        # Keep the call made directly from here and introduce no extra
        # locals before it.
        super().__init__()
        self._parse_args()
        self._built = False

    def build(self, *args, **kwargs):
        r"""Builds the Node.

        Ideally, this should not be called manually. When an unbuilt module
        is first called, this method gets invoked. It marks the Node as built
        and moves it to ``mag.device``.
        """
        self._built = True
        self.to(mag.device)

    def __call__(self, *args, **kwargs):
        # (Re)build whenever the Node is not yet built, or whenever
        # mag.build_lock is falsy; otherwise go straight to the forward pass.
        if not (self._built and mag.build_lock):
            self.build(*args, **kwargs)
        return super().__call__(*args, **kwargs)

    def _parse_args(self):
        r"""Stores all the constructor arguments in ``_args``.

        This allows the arguments to be modified at runtime. It also captures
        the name of the Node, if given (default is the class name).

        Raises:
            ValueError: If the given name is ``None`` or empty.
        """
        args = caller_locals(ancestor=True)
        # Flatten any **kwargs captured from the constructor into the dict.
        args.update(args.pop('kwargs', {}))

        self.name = args.pop('name', self.__class__.__name__)
        if self.name is None or self.name == '':
            raise ValueError(f"""One of the {self.__class__.__name__} Node's names is {self.name}""")

        self._args = args

    def get_args(self):
        r"""Returns a nicely formatted ``key=value, ...`` string describing the arguments."""
        return ', '.join(str(k) + '=' + str(v) for k, v in self._args.items())

    def to(self, *args, **kwargs):
        r"""Moves the Node like :py:meth:`torch.nn.Module.to` and records its device.

        After moving, the device of the first parameter is stored in a
        convenient ``device`` attribute. Nodes without parameters leave
        ``device`` untouched.

        Returns:
            self
        """
        super().to(*args, **kwargs)

        try:
            # Read the device off the parameter itself; indexing it first
            # (``param[0]``) would fail for 0-dim (scalar) parameters.
            self.device = next(self.parameters()).device
        except StopIteration:
            pass

        return self

    def load_state_dict(self, f, strict=True):
        r"""Loads parameters and buffers into this Node.

        Args:
            f (dict, str or pathlib.Path): Either a state dict, or a path to a
                file saved with ``torch.save``. Passing a path is preferred,
                since the tensors are then mapped onto this Node's device.
            strict (bool): Passed through to
                :py:meth:`torch.nn.Module.load_state_dict`. Default: ``True``

        Returns:
            Whatever :py:meth:`torch.nn.Module.load_state_dict` returns.
        """
        from pathlib import Path

        if isinstance(f, (str, Path)):
            # Map straight onto the recorded device, keeping the CUDA index
            # (cuda:1 stays cuda:1). If no device was ever recorded (e.g. no
            # parameters), fall back to mag.device, as build() does.
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            device = getattr(self, 'device', None)
            if device is None:
                device = mag.device
            f = torch.load(f, map_location=device)

        return super().load_state_dict(f, strict=strict)

    def _mul_int(self, n):
        r"""Returns ``self`` followed by ``n - 1`` fresh copies built from ``_args``.

        For ``n < 2`` the list contains only ``self``.
        """
        return [self] + [self.__class__(**self._args) for _ in range(int(n) - 1)]

    def _mul_list(self, n):
        r"""A useful overload of the ``*`` operator that can create similar copies of the Node.

        Args:
            n (tuple or list): The modifier supplied.

        The modifier ``n`` should be used to change the arguments of the Node
        in a meaningful way. For instance, in the case of a ``Linear`` Node,
        the items in ``n`` can be interpreted as the output dimensions of each
        layer.

        Raises:
            NotImplementedError: Subclasses must override this.
        """
        raise NotImplementedError

    def __mul__(self, n):
        r"""Overloads ``*`` to create similar copies of this Node.

        Args:
            n (int, integral float, tuple or list): An integer (or a float
                with an integral value) yields ``n`` Nodes via ``_mul_int``.
                A tuple or list is handed to :py:meth:`_mul_list`.

        Returns:
            list: The Nodes. Any other operand type returns
            ``NotImplemented``, so Python raises ``TypeError``.
        """
        if isinstance(n, int) or (isinstance(n, float) and n.is_integer()):
            # Convert so that integral floats work with range() downstream.
            return self._mul_int(int(n))
        if isinstance(n, (tuple, list)):
            return self._mul_list(n)
        return NotImplemented