Source code for nets.nn.modules.module

r"""
Modules are the main building blocks for all transformations from one tensor to another. In other words,
a neural network is a succession of modules (layers, convolutions, activations...).
When building a custom neural network, your model must inherit from the ``Module`` abstract class and override the
``forward`` method. Optionally, you can specify the back-propagation rule in the ``backward`` method. Usually, the
``backward`` method computes the naive back-propagation using only the local gradients saved in the module's ``_cache``.
If you don't specify it, **NETS** will use its ``autograd`` functionality to compute all gradients.
"""

from collections import OrderedDict
from abc import ABC, abstractmethod
import inspect
import warnings
import json
import pickle
from nets import Parameter


class Module(ABC):
    """
    Abstract Module architecture.

    All models used to transform tensors should extend this class to benefit from the
    ``forward`` and ``backward`` propagation rules.
    """

    def __init__(self):
        self.training = True
        self._modules = OrderedDict()
        self._params = OrderedDict()
        self._grads = OrderedDict()
        self._cache = OrderedDict()
    @abstractmethod
    def forward(self, *inputs):
        """One forward step.

        Gradients and outputs should be saved in the ``_cache`` when training,
        to be able to perform the backward pass.
        """
        raise NotImplementedError

    def backward(self, *outputs):
        """One backward step."""
        raise NotImplementedError
    def train(self):
        """Set the ``training`` attribute to training mode."""
        self.training = True
        for param in self.parameters():
            param.requires_grad = True

    def eval(self):
        """Set the ``training`` attribute to evaluation mode."""
        self.training = False
        for param in self.parameters():
            param.requires_grad = False
    def add(self, *modules):
        """Add modules to the current one.

        Args:
            modules (Module): modules to add
        """
        for module in modules:
            idx = len(self._modules)
            name = f"{idx}"
            setattr(self, name, module)
            self._modules[name] = module

    def parameters(self):
        """Iterator through all parameters."""
        for name, value in inspect.getmembers(self):
            if isinstance(value, Parameter):
                yield value
            elif isinstance(value, Module):
                yield from value.parameters()
    def modules(self):
        """Iterator through all sub-modules."""
        yield from self._modules.values()

    def cache(self):
        """Iterator through all cache dictionaries."""
        for module in self.modules():
            yield module._cache

    def gradients(self):
        """Iterator through all gradient dictionaries."""
        for module in self.modules():
            yield module._grads

    def zero_grad(self):
        """Zero the gradients of all parameters within a module."""
        for parameter in self.parameters():
            parameter.zero_grad()
    def state_dict(self):
        r"""Save all parameters in a dictionary."""
        state = OrderedDict()
        for i, param in enumerate(self.parameters()):
            state[f'param{i}'] = param.tolist()
        return state
    def load_state(self, state_dict):
        r"""Load parameters from a ``state_dict`` dictionary."""
        for i, param in enumerate(self.parameters()):
            # Values were stored as nested lists by ``state_dict``; rebuild a Parameter
            # so shapes can be compared.
            data = Parameter(data=state_dict[f'param{i}'])
            if param.shape != data.shape:
                warnings.warn(f"shape from the `state_dict` does not match the model's parameter shape. "
                              f"Got {data.shape}, expected {param.shape}.",
                              UserWarning, stacklevel=2)
            param.data = data
    def save(self, filename='model.pickle'):
        """Save the model as a pickle file."""
        with open(filename, 'wb') as f:
            pickle.dump(self, f)

    def save_dict(self, filename='state_dict.json'):
        """Save the state as a JSON file."""
        state = self.state_dict()
        with open(filename, 'w') as f:
            json.dump(state, f)
    def get_name(self):
        """Quick access to get the name of a module.

        Returns:
            string: the module's name
        """
        return self.__class__.__name__

    def inner_repr(self):
        """Return the representation of a single module.

        This method should be unique for each module.

        Returns:
            string: the representation of one module.
        """
        return ""
    def __repr__(self):
        # Representation similar to PyTorch
        string = f"{self.get_name()}("
        tab = "    "
        modules = self._modules
        if modules == {}:
            string += f'\n{tab}(parameters): {self.inner_repr()}'
        else:
            for key, module in modules.items():
                string += f"\n{tab}({key}): {module.get_name()}({module.inner_repr()})"
        return f'{string}\n)'

    def __call__(self, *inputs):
        return self.forward(*inputs)

    def __setattr__(self, key, value):
        # First initialize the attribute we want to add
        self.__dict__[key] = value
        # Then update the inner dictionaries '_modules' and '_params'
        if isinstance(value, Module):
            self._modules[key] = value
        elif isinstance(value, Parameter):
            self._params[key] = value
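
As a usage sketch (not part of the source above), the methods compose as follows. The ``Bias`` container arithmetic and the ``Parameter(data=[...])`` call are hypothetical, and the import path is again assumed from this file's location:

    from nets import Parameter
    from nets.nn.modules.module import Module  # import path assumed from this file's location


    class Bias(Module):
        """Hypothetical module adding a learnable offset to its input."""

        def __init__(self):
            super().__init__()
            self.bias = Parameter(data=[0.0])  # constructor usage assumed, mirrors ``load_state``

        def forward(self, *inputs):
            return inputs[0] + self.bias


    class Chain(Module):
        """Hypothetical container running its children in insertion order."""

        def __init__(self, *modules):
            super().__init__()
            self.add(*modules)           # children registered in ``_modules`` as "0", "1", ...

        def forward(self, *inputs):
            out = inputs[0]
            for module in self.modules():
                out = module(out)        # ``__call__`` dispatches to each child's ``forward``
            return out


    model = Chain(Bias(), Bias())
    print(model)                         # PyTorch-like ``__repr__`` listing the children
    state = model.state_dict()           # nested-list snapshot of every parameter
    model.save_dict('state_dict.json')   # writes the same snapshot as JSON
    model.load_state(state)              # restores parameters, checking shapes along the way

Registering children by index through ``add`` mirrors how PyTorch's ``Sequential`` names its layers, which is why ``__repr__`` prints them as ``(0)``, ``(1)``, and so on.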
