Posted: Nov 18, 2024
A hybrid design: using nn.Modules with @dataclasses
About this tutorial
Python’s dataclasses and PyTorch’s nn.Modules are rarely used together.
Perhaps the communities using these features simply don’t overlap much, but I suspect the main reason is that combining them isn’t quite straightforward.
In this tutorial, we’ll explore some of the design and inner workings of nn.Module, and some rules about Python classes and dataclasses.
Along the way, we’ll encounter a few pitfalls — and learn how to overcome them.
By the end of this tutorial, we’ll have implemented a small hybrid class that integrates both.
The format of this tutorial
scrolling
This post is also an experiment with a new tutorial format.
As you scroll, you will see code being written live in a virtual terminal on the right (or at the bottom if you are on mobile).
Scrolling works a bit differently on this site. You can scroll within this text or on the right side of the virtual terminal (sorry, lefties).
typing and navigating
If you like typing tests, and want to practice a bit, you can follow along with the tutorial by typing exactly as shown on the key press overlay below the terminal. The content will automatically scroll with your progress.
You can also click or tap on a letter on the overlay to jump to that specific point in time.
Pressing the left arrow key will take you one step back.
tmux
The virtual terminal is running tmux for multiplexing, so you will see some key-combinations that switch between the code editor and the terminal used for running code.
neovim
Neovim is used for code editing, and its normal mode commands frequently appear throughout the tutorial.
Understanding these could be challenging if you’re not already familiar with Vim motions. Feel free to skip them or try to figure out how they work and learn a bit of Vim in practice. However, this tutorial isn’t focused on teaching Vim motions — consider this your warning.
Some Neovim plugins will also appear at certain points; these have some unique keybindings too.
A Point in 2D
Let’s look at a typical example. In modern Python, we can create a simple class to represent a 2D point.
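Reconstructed from the outputs shown below, the code in the terminal looks roughly like this (a sketch):

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

if __name__ == '__main__':
    p = Point(3, 4)
    print(p)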
Let’s run this code
The output we get is the representation of our point object.
<__main__.Point object at 0x716bd3f19e90>

This is a bit opaque — it only tells us the type and the memory address of the object.
__repr__ for transparency
Often, we’re more interested in the contents of an object rather than its identity.
To get a more human-friendly representation, we can override the __repr__ method of the class.
This will make the output clearer and more informative, making our lives a bit easier when debugging.
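A minimal override, added inside the Point class, could look like this:

def __repr__(self):
    return f'Point(x={self.x}, y={self.y})'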
The new representation of a point will look like this:
Point(x=3, y=4)

equality
For fun, we can check if two points are equal
p = Point(3, 4)
q = Point(5, 12)
(p == q) is False

This seems intuitive so far, but what if we compare two point objects that we expect to be equal because they have the same coordinates?
(Point(3,4) == Point(3,4)) is False

By default, equality in Python is based on the identity of the objects, so a point will only be equal to itself:
(p == p) is True

To override this behaviour, let's create a custom __eq__ method, where we compare the coordinates.
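A minimal version, skipping any type checking of other, might be:

def __eq__(self, other):
    return (self.x, self.y) == (other.x, other.y)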
(Point(3,4) == Point(3,4)) is True

The output is now more in line with what we would expect when working with Points.
Point(x=3, y=4) Point(x=5, y=12) Point(x=3, y=4)
True False True

Eliminating boilerplate with @dataclass
A fairly simple example so far, yet we have written quite a bit of code already, only to achieve what could be considered a reasonable default.
This is where dataclasses come into play. We can get the same exact functionality with far fewer lines of code.
The @dataclass decorator automatically generates the __init__, __repr__, and __eq__ methods for the class, making our code much more concise while maintaining the same functionality.
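The dataclass version of Point shrinks to a handful of lines (the int annotations are an assumption; they are not enforced at runtime):

from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int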
By running the code again, we get the same output as before
Point(x=3, y=4) Point(x=5, y=12) Point(x=3, y=4)
True False True

This simplification also clarifies the intent of the code. By not explicitly defining the __init__, __repr__, and __eq__ methods, we’re expressing that we want the default, most straightforward implementation of those.
Adding more methods works as usual.
For example, a p-norm could look like this
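Something along these lines (the method name and the default p=2 are guesses, chosen to be consistent with the 5.0 and 13.0 values printed below):

def p_norm(self, p=2):
    return (abs(self.x) ** p + abs(self.y) ** p) ** (1 / p)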
output
The Pythagorean triples are now complete with this output
...
5.0
13.0

Weightlifting with nn.Modules
A simple model
Now let’s look at a simple custom model definition using PyTorch.
For a usual model we need 4 core elements:
- subclassing nn.Module
- calling super().__init__ at the beginning of __init__
- defining the building blocks of the model (usually other nn.Module subclasses)
- implementing a forward method (this is where we define how the data flows and transforms through our model)
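Putting these four pieces together, the model in the terminal is roughly the following (a sketch; the layer sizes are taken from the printed representation shown further down):

import torch
from torch import nn

class SingleLayer(nn.Module):
    def __init__(self, in_dim=1764, out_dim=512):
        super().__init__()  # set up nn.Module's internal state before anything else
        self.linear = nn.Linear(in_dim, out_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        # the data flows through the linear layer, then the activation
        return self.relu(self.linear(x))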
about init
The __init__ looks simple, but there is in fact a lot happening behind the scenes.
The super().__init__() call is crucial here: it sets up the internal mechanisms of nn.Module.
From the PyTorch documentation:
an __init__() call to the parent class must be made before assignment on the child.
So in our example, super().__init__() has to happen before self.linear = ...
We now have but one choice: We must face the long dark of PyTorch
nn.Module has its own __setattr__ implementation, so the self.linear = ... and self.relu = ... assignments aren’t as simple as they appear.
To understand how the internal state management works, we can take a look at the source code of nn.Module.__init__
The comment here captures the main idea:
"""
Calls `super().__setattr__('a', a)` instead of the typical `self.a = a`
to avoid `Module.__setattr__` overhead. Module's `__setattr__` has special
handling for `parameter`s, `submodule`s, and `buffer`s but simply calls into
`super().__setattr__` for all other attributes.
"""internal dictionaries for modules
Oversimplifying a bit: if we are assigning an nn.Module as an attribute like this:

self.linear = nn.Linear(...)

the special handling in __setattr__ will see that the value we are assigning has the type nn.Linear (which is a subclass of nn.Module), and will store it in the _modules dictionary:

_modules['linear'] = nn.Linear(...)

Parameters and buffers are handled similarly.
This internal state management enables more ergonomic code: for example, we can just call .to on a Module, and it will move all of its parameters (including the parameters of its submodules) to the given device.
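A quick way to see both of these, reusing the SingleLayer sketch from above (the CUDA check is only there so the snippet also runs on CPU-only machines):

model = SingleLayer()

# the submodules assigned in __init__ ended up in the _modules dict
print(list(model._modules.keys()))  # ['linear', 'relu']

# .to() walks this internal state and moves every parameter at once
if torch.cuda.is_available():
    model = model.to('cuda')
    print(model.linear.weight.device)  # cuda:0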
__repr__
nn.Module has its own __repr__ implementation that makes use of the _modules dictionary to include the building blocks of the model in its printed representation.
SingleLayer(
(linear): Linear(in_features=1764, out_features=512, bias=True)
(relu): ReLU()
)

A more flexible NeuralNetwork
Let’s create another NeuralNetwork, now with a few hyperparameters: we create parameters for the input-, hidden- and output layer dimensions as well as the number of hidden layers.
And we can create the actual building blocks of the network: here they are Linear layers with ReLU activations.
We pass these to Sequential to get our net, and we also define the loss to be the Mean Squared Error.
The forward is really simple: we just pass the input x to net, and calculate the loss, which is just the mean squared error between the prediction y and the target.
In the __main__, we simply instantiate the model, and print it.
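Reconstructed from the printed model below, the class might look something like this (the default hyperparameter values and the forward signature are assumptions):

from torch import nn
from torch.nn import Linear, MSELoss, ReLU, Sequential

class NeuralNetwork(nn.Module):
    def __init__(self, x_dim=4096, h_dim=42, y_dim=1, num_hidden_layers=3):
        super().__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.y_dim = y_dim
        self.num_hidden_layers = num_hidden_layers

        # Linear layers with ReLU activations, wrapped in a Sequential
        layers = []
        in_dim = x_dim
        for _ in range(num_hidden_layers):
            layers += [Linear(in_dim, h_dim), ReLU()]
            in_dim = h_dim
        layers += [Linear(in_dim, y_dim), ReLU()]
        self.net = Sequential(*layers)

        self.loss = MSELoss()

    def forward(self, x, target):
        y = self.net(x)
        return self.loss(y, target)

if __name__ == '__main__':
    model = NeuralNetwork()
    print(model)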
Let’s see the output of this:
NeuralNetwork(
(net): Sequential(
(0): Linear(in_features=4096, out_features=42, bias=True)
(1): ReLU()
(2): Linear(in_features=42, out_features=42, bias=True)
(3): ReLU()
(4): Linear(in_features=42, out_features=42, bias=True)
(5): ReLU()
(6): Linear(in_features=42, out_features=1, bias=True)
(7): ReLU()
)
(loss): MSELoss()
)

first attempt at @dataclass
There’s quite a bit of boilerplate in this class; we can again try to eliminate some of it with @dataclass.
To keep the generated __init__ intact, we can create net and loss in the __post_init__ method. __post_init__ is automatically called at the end of the @dataclass generated __init__.
Next, we can also define the attributes created in __post_init__ as fields in the class.
By marking these with init=False, they become excluded from the __init__ method, and we become responsible for initializing them elsewhere, such as in the __post_init__.
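A sketch of this first attempt (the repr=False flags are an assumption, matching the short printed representations we will see later); as the next section shows, it does not quite run yet:

from dataclasses import dataclass, field

from torch import nn
from torch.nn import Linear, MSELoss, ReLU, Sequential

@dataclass
class NeuralNetwork(nn.Module):
    x_dim: int = 4096
    h_dim: int = 42
    y_dim: int = 1
    num_hidden_layers: int = 3
    # created in __post_init__, so they are excluded from the generated __init__
    net: Sequential = field(init=False, repr=False)
    loss: MSELoss = field(init=False, repr=False)

    def __post_init__(self):
        layers = []
        in_dim = self.x_dim
        for _ in range(self.num_hidden_layers):
            layers += [Linear(in_dim, self.h_dim), ReLU()]
            in_dim = self.h_dim
        layers += [Linear(in_dim, self.y_dim), ReLU()]
        self.net = Sequential(*layers)
        self.loss = MSELoss()

    def forward(self, x, target):
        y = self.net(x)
        return self.loss(y, target)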
The code is now a bit less repetitive, and the intent is also more clear.
As an added benefit, we get a separation between the attributes that are passed to the class and the ones that are created during its initialization: net and loss can be thought of as owned by the NeuralNetwork.
Not so superb init
The code in its current state suffers from a few errors, so let’s try to fix the first one
AttributeError: cannot assign module before Module.__init__() call

We forgot to call super().__init__. This has to happen somewhere before assigning a submodule like self.net = Sequential(*layers).
Naively, we could try to place this in the __post_init__ too:
def __post_init__(self):
    super().__init__()
    ...

This seems to work:
NeuralNetwork(x_dim=4096, h_dim=42, y_dim=1, num_hidden_layers=3)

But we have to keep in mind that super().__init__ is supposed to be called at the beginning of initialization, otherwise we might run into more issues.
For example, let’s try a small refactor: since MSELoss doesn’t depend on any of the input arguments of __init__, we could pass a default_factory to its field.
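In the sketch above, that would mean swapping the init=False field of loss for a default_factory:

loss: MSELoss = field(default_factory=MSELoss, repr=False)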
And now we get the same error again
AttributeError: cannot assign module before Module.__init__() call

We need a super().__init__ call before assigning PyTorch-specific values like an instance of MSELoss (which itself is a subclass of nn.Module).
But the default_factory is called in __init__, and there is an attempt to store this as an attribute like so
self.loss = MSELoss()

super().__init__ is only called after this, since we placed it in the __post_init__ method.
Ideally, super().__init__() should happen at the very beginning of our __init__, but we would also like to rely on the @dataclass-generated __init__, so we don’t have to write the tedious self.x = x lines.
Unfortunately, when using dataclasses, there is no corresponding __pre_init__ hook that would get called before __init__.
It is possible to hack something together with __new__ (since this happens before __init__), but let’s instead look into another library to help us.
switching libraries: attrs
The attrs library has the same simplifying power as the built-in dataclasses, and in some cases it’s even more powerful (https://www.attrs.org/en/stable/why.html#why-not)
One of those cases is if we want to hook into initialization (https://www.attrs.org/en/stable/init.html#hooking-yourself-into-initialization)
So first let’s install the library:
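For example, with pip:

pip install attrs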
Next we can replace the previous imports from dataclasses with attrs
And we can use __attrs_pre_init__ to call torch.nn.Module.__init__.
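Put together, the attrs version might look like this (again a sketch; the repr=False flags are only there to keep the printed representation short, matching the output below):

from attrs import define, field

from torch import nn
from torch.nn import Linear, MSELoss, ReLU, Sequential

@define
class NeuralNetwork(nn.Module):
    x_dim: int = 4096
    h_dim: int = 42
    y_dim: int = 1
    num_hidden_layers: int = 3
    net: Sequential = field(init=False, repr=False)
    loss: MSELoss = field(factory=MSELoss, repr=False)

    def __attrs_pre_init__(self):
        # runs before the attrs-generated __init__ assigns anything
        nn.Module.__init__(self)

    def __attrs_post_init__(self):
        layers = []
        in_dim = self.x_dim
        for _ in range(self.num_hidden_layers):
            layers += [Linear(in_dim, self.h_dim), ReLU()]
            in_dim = self.h_dim
        layers += [Linear(in_dim, self.y_dim), ReLU()]
        self.net = Sequential(*layers)

    def forward(self, x, target):
        y = self.net(x)
        return self.loss(y, target)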
Now this seems to work; there are no more visible errors in the class:

NeuralNetwork(x_dim=4096, h_dim=42, y_dim=1, num_hidden_layers=3)

Hashenanigans
Let’s try to do the very first step of a usual training setup, which would be simply passing model.parameters() to an optimizer.
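For example (the choice of SGD and the learning rate are arbitrary here):

import torch

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)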
The error
TypeError: unhashable type: 'NeuralNetwork'

indicates that an nn.Module is expected to be hashable, which our class is not right now.
The cause for this lies in how attrs works: it turns out that using @define makes the class unhashable.
There is a great explanation of why that happens in the documentation (https://www.attrs.org/en/stable/hashing.html)
Basically, if there is a custom __eq__ implementation in our class (and by default attrs creates one), Python will make it unhashable.
An easy fix is adding @define(eq=False), so that attrs doesn’t create an __eq__ for us (and in our case comparing models directly doesn’t really make sense anyway).
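So only the decorator line changes:

@define(eq=False)
class NeuralNetwork(nn.Module):
    ...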
There’s no error anymore, and finally our hybridization is complete.
One last potential bug
You should not initialize an optional attribute with None.
This is a subtle bug, and one that is not specific to our current setup: it could also occur in a usual custom PyTorch model. However, with the current setup, there is a bigger chance that you would run into it.
If we first assign e.g. loss=None, nn.Module.__setattr__ will fall through to a super().__setattr__('loss', None) call, storing None as a regular attribute (since None is not an nn.Module, it never reaches the _modules dictionary).
If we then reassign another value to loss, the new value does get registered in the _modules dictionary internally, but when accessing model.loss later we will always get back None, because the regular attribute lookup takes priority in Python over nn.Module.__getattr__ (which is only consulted when the normal lookup fails).
So don’t do this:
@define(eq=False)
class NeuralNetwork(nn.Module):
    ...
    loss: MSELoss = field(default=None)

    def __attrs_post_init__(self):
        ...
        self.loss = MSELoss()

if __name__ == '__main__':
    ...
    print(model.loss is None)

And even though we re-assigned self.loss, we would still get the output
...
True
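A safer pattern is to avoid the None default altogether, for example by using a factory (as we did earlier), so that the attribute is a real module from the start:

loss: MSELoss = field(factory=MSELoss)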
Drawbacks of the hybrid class

Compared to the vanilla version of our model, the new pattern does introduce a bit of overhead, and you need to understand some of the internals of both nn.Module and dataclasses (or attrs).
Benefits of this pattern
A nicer debugging experience
(Print-)debugging a model will now produce a dataclass-flavoured output, and you can also more easily fine-tune this by using field(repr=False) to hide uninteresting parts of the model when printing.
Division between passed arguments and object-managed attributes
This mostly goes for the intent of the code. Usually there is a clear division between arguments that are provided as input to the __init__ (but we might want to save these for later use, so they also become part of the object), and things that are created inside __post_init__ (these are completely “owned” by the object).
For complex models this could make the code a bit easier to understand.