Posted: Nov 18, 2024
A hybrid design: using nn.Modules with @dataclasses
About this tutorial
Python’s dataclasses and PyTorch’s nn.Modules are rarely used together.
The communities using these features might just have too little of an overlap, but I suspect the main reason is that combining them isn’t quite straightforward.
In this tutorial, we’ll explore some of the design and inner workings of nn.Module, and some rules about Python classes and dataclasses.
Along the way, we’ll encounter a few pitfalls — and learn how to overcome them.
By the end of this tutorial, we’ll have implemented a small hybrid class that integrates both.
The format of this tutorial
scrolling
This post is also an experiment on a new tutorial format.
As you scroll, you will see code being written live in a virtual terminal on the right (or at the bottom if you are on mobile).
Scrolling works a bit differently on this site. You can scroll within this text or on the right side of the virtual terminal (sorry, lefties).
typing and navigating
If you like typing tests, and want to practice a bit, you can follow along with the tutorial by typing exactly as shown on the key press overlay below the terminal. The content will automatically scroll with your progress.
You can also click or tap on a letter on the overlay to jump to that specific point in time.
Pressing the left arrow key will take you one step back.
tmux
The virtual terminal is running tmux for multiplexing, so you will see some key-combinations that switch between the code editor and the terminal used for running code.
neovim
Neovim is used for code editing, and its normal mode commands frequently appear throughout the tutorial.
Understanding these could be challenging if you’re not already familiar with Vim motions. Feel free to skip them or try to figure out how they work and learn a bit of Vim in practice. However, this tutorial isn’t focused on teaching Vim motions — consider this your warning.
Some Neovim plugins will also appear at certain points; these have some unique keybindings too.
A Point in 2D
Let’s look at a typical example. In modern Python, we can create a simple class to represent a 2D point.
Let’s run this code
The output we get is the representation of our point object.
<__main__.Point object at 0x716bd3f19e90>
This is a bit opaque — it only tells us the type and the memory address of the object.
__repr__ for transparency
Often, we’re more interested in the contents of an object rather than its identity.
To get a more human-friendly representation, we can override the __repr__ method of the class.
This will make the output clearer and more informative, making our lives a bit easier when debugging.
The new representation of a point will look like this:
Point(x=3, y=4)
equality
For fun, we can check if two points are equal
p = Point(3, 4)
q = Point(5, 12)
(p == q) is False
This seems intuitive so far, but what if we compare two point objects that we expect to be equal because they have the same coordinates?
(Point(3,4) == Point(3,4)) is False
By default, equality in Python is based on the identity of the objects, so a point will only be equal to itself
(p == p) is True
To override this behaviour, let’s create a custom __eq__ method, where we compare the coordinates
(Point(3,4) == Point(3,4)) is True
The output is now more in line with what we would expect when working with Points.
Point(x=3, y=4) Point(x=5, y=12) Point(x=3, y=4)
True False True
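For reference, the hand-written class at this point could look something like this (a sketch; the exact code in the terminal may differ slightly):

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        # matches the Point(x=3, y=4) style output shown above
        return f"Point(x={self.x}, y={self.y})"

    def __eq__(self, other):
        # two points are equal if their coordinates match
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x and self.y == other.y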
Eliminating boilerplate with @dataclass
A fairly simple example so far, yet we have written quite a bit of code already, only to achieve what could be considered a reasonable default.
This is where dataclasses come into play. We can get the exact same functionality with far fewer lines of code.
The @dataclass decorator automatically generates the __init__, __repr__, and __eq__ methods for the class, making our code much more concise while maintaining the same functionality.
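A sketch of the dataclass version, assuming the same two fields as before:

from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int

That’s it: __init__, __repr__ and __eq__ are generated from the field declarations.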
By running the code again, we get the same output as before
Point(x=3, y=4) Point(x=5, y=12) Point(x=3, y=4)
True False True
This simplification also clarifies the intent of the code. By not explicitly defining the __init__, __repr__, and __eq__ methods, we’re expressing that we want the default, most straightforward implementation of those.
Adding more methods works as usual.
For example, a p-norm could look like this
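A sketch of such a method added to the Point class (the method name and exact signature in the terminal may differ):

from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int

    def p_norm(self, p: float = 2) -> float:
        # p = 2 gives the familiar Euclidean length
        return (abs(self.x) ** p + abs(self.y) ** p) ** (1 / p)

print(Point(3, 4).p_norm())   # 5.0
print(Point(5, 12).p_norm())  # 13.0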
output
The Pythagorean triples are now complete with this output
...
5.0
13.0
Weightlifting with nn.Modules
A simple model
Now let’s look at a simple custom model definition using PyTorch.
For a usual model we need 4 core elements (put together in the sketch after this list):
- subclassing nn.Module
- calling super().__init__ at the beginning of __init__
- defining the building blocks of the model (usually other nn.Module subclasses)
- implementing a forward method (this is where we define how the data flows and transforms through our model)
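Combined, the model could look roughly like this (a sketch; the layer sizes are taken from the printed representation shown later, and the forward body is my assumption):

from torch import nn

class SingleLayer(nn.Module):
    def __init__(self):
        # set up nn.Module's internal state before assigning any submodules
        super().__init__()
        # building blocks of the model
        self.linear = nn.Linear(1764, 512)
        self.relu = nn.ReLU()

    def forward(self, x):
        # how the data flows through the model
        return self.relu(self.linear(x))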
about init
The __init__ looks simple, but there is in fact a lot happening behind the scenes.
The super().__init__() call is crucial here: it sets up the internal mechanisms of nn.Module.
From the PyTorch documentation:
an __init__() call to the parent class must be made before assignment on the child.
So in our example, super().__init__() has to happen before self.linear = ...
We now have but one choice: We must face the long dark of PyTorch
nn.Module has its own __setattr__ implementation, so the self.linear = ... and self.relu = ... assignments aren’t as simple as they appear.
To understand how the internal state management works, we can take a look at the source code of nn.Module.__init__
The comment here captures the main idea:
"""
Calls `super().__setattr__('a', a)` instead of the typical `self.a = a`
to avoid `Module.__setattr__` overhead. Module's `__setattr__` has special
handling for `parameter`s, `submodule`s, and `buffer`s but simply calls into
`super().__setattr__` for all other attributes.
"""
internal dictionaries for modules
Overly simplified: if we are assigning an nn.Module as an attribute like this:
self.linear = nn.Linear(...)
the special handling in __setattr__ will see that the value we are assigning has the type nn.Linear (which is a subclass of nn.Module), and will store it in the _modules dictionary:
_modules['linear'] = nn.Linear(...)
Parameters and buffers are also handled similarly.
This internal state management enables more ergonomic code: for example, we can just call .to on a Module, and it will move all of its parameters (including the parameters of its submodules) to the given device.
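For example, with the SingleLayer sketch from above (the device choice here is just for illustration):

import torch

model = SingleLayer()
# a single .to call recursively moves the parameters of every submodule
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(next(model.parameters()).device)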
__repr__
nn.Module has its own __repr__ implementation, which makes use of the _modules dictionary to include the building blocks of the model in its printed representation.
SingleLayer(
(linear): Linear(in_features=1764, out_features=512, bias=True)
(relu): ReLU()
)
A more flexible NeuralNetwork
Let’s create another NeuralNetwork, now with a few hyperparameters: the input, hidden and output layer dimensions, as well as the number of hidden layers.
And we can create the actual building blocks of the network: here they are Linear layers with ReLU activations.
We pass these to Sequential to get our net, and we also define the loss to be the Mean Squared Error.
The forward is really simple: we pass the input x to net and calculate the loss, which is just the mean squared error between the prediction y and the target.
In the __main__, we simply instantiate the model and print it.
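As a sketch, a class matching the printed representation below could look like this (the exact way the layers are assembled in the terminal may differ):

from torch import nn
from torch.nn import Linear, MSELoss, ReLU, Sequential

class NeuralNetwork(nn.Module):
    def __init__(self, x_dim, h_dim, y_dim, num_hidden_layers):
        super().__init__()
        # input layer, hidden layers and output layer, each followed by a ReLU
        layers = [Linear(x_dim, h_dim), ReLU()]
        for _ in range(num_hidden_layers - 1):
            layers += [Linear(h_dim, h_dim), ReLU()]
        layers += [Linear(h_dim, y_dim), ReLU()]
        self.net = Sequential(*layers)
        self.loss = MSELoss()

    def forward(self, x, target):
        # prediction, followed by the mean squared error against the target
        y = self.net(x)
        return self.loss(y, target)

if __name__ == "__main__":
    model = NeuralNetwork(x_dim=4096, h_dim=42, y_dim=1, num_hidden_layers=3)
    print(model)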
Let’s see the output of this:
NeuralNetwork(
(net): Sequential(
(0): Linear(in_features=4096, out_features=42, bias=True)
(1): ReLU()
(2): Linear(in_features=42, out_features=42, bias=True)
(3): ReLU()
(4): Linear(in_features=42, out_features=42, bias=True)
(5): ReLU()
(6): Linear(in_features=42, out_features=1, bias=True)
(7): ReLU()
)
(loss): MSELoss()
)
first attempt at @dataclass
There’s quite a bit of boilerplate in this class; we can again try to eliminate some of it with @dataclass.
To keep the generated __init__ intact, we can create net and loss in the __post_init__ method. __post_init__ is automatically called at the end of the @dataclass-generated __init__.
Next, we can also define the attributes created in __post_init__ as fields in the class. By marking these with init=False, they are excluded from the generated __init__, and we become responsible for initializing them elsewhere, such as in __post_init__.
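A sketch of what the class could look like at this point (the field names follow the earlier version; the terminal code may differ slightly):

from dataclasses import dataclass, field

from torch import nn
from torch.nn import Linear, MSELoss, ReLU, Sequential

@dataclass
class NeuralNetwork(nn.Module):
    x_dim: int
    h_dim: int
    y_dim: int
    num_hidden_layers: int
    # created in __post_init__, so excluded from the generated __init__
    net: Sequential = field(init=False)
    loss: MSELoss = field(init=False)

    def __post_init__(self):
        layers = [Linear(self.x_dim, self.h_dim), ReLU()]
        for _ in range(self.num_hidden_layers - 1):
            layers += [Linear(self.h_dim, self.h_dim), ReLU()]
        layers += [Linear(self.h_dim, self.y_dim), ReLU()]
        self.net = Sequential(*layers)
        self.loss = MSELoss()

    def forward(self, x, target):
        return self.loss(self.net(x), target)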
The code is now a bit less repetitive, and the intent is also clearer.
As an added benefit, we get a separation between the attributes that are passed to the class and the ones that are created during its initialization: net and loss can be thought of as owned by the NeuralNetwork.
Not so superb init
The code in its current state suffers from a few errors, so let’s try to fix the first one
AttributeError: cannot assign module before Module.__init__() call
We forgot to call super().__init__. This has to happen somewhere before assigning a submodule like self.net = Sequential(*layers).
Naively, we could try to place this in the __post_init__ too:
def __post_init__(self):
super().__init__()
This seems to work:
NeuralNetwork(x_dim=4096, h_dim=42, y_dim=1, num_hidden_layers=3)
But we have to keep in mind that super().__init__ is supposed to be called at the beginning of initialization; otherwise we might run into more issues.
For example, let’s try a small refactor: since MSELoss doesn’t depend on any of the input arguments of __init__, we could pass a default_factory to its field.
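Something like this (still keeping init=False; this is my reading of the refactor):

# instead of creating the loss in __post_init__ ...
loss: MSELoss = field(default_factory=MSELoss, init=False)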
And now we get the same error again
AttributeError: cannot assign module before Module.__init__() call
We need a super().__init__ call before assigning PyTorch-specific values like an instance of MSELoss (which itself is a subclass of nn.Module).
But the default_factory is called in __init__, and there is an attempt to store this as an attribute like so
self.loss = MSELoss()
super().__init__ is only called after this, since we placed it in the __post_init__ method.
Ideally super().__init__() should happen at the very beginning of our __init__, but we would also like to rely on the @dataclass-generated __init__, so we don’t have to write the tedious self.x = x lines.
Unfortunately there is no corresponding __pre_init__ function that would get called before __init__ when using dataclasses.
It is possible to hack something together with __new__ (since this happens before __init__), but let’s instead look into another library to help us.
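For the curious, such a __new__ hack could look roughly like this (a sketch of the idea only, placed inside the @dataclass-decorated class; we won’t pursue it further):

def __new__(cls, *args, **kwargs):
    # runs before __init__, so nn.Module's internal state exists
    # by the time the dataclass-generated __init__ assigns attributes
    instance = super().__new__(cls)
    nn.Module.__init__(instance)
    return instance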
switching libraries: attrs
The attrs library has the same simplifying power as the built-in dataclasses, and in some cases it’s even more powerful (https://www.attrs.org/en/stable/why.html#why-not)
One of those cases is if we want to hook into initialization (https://www.attrs.org/en/stable/init.html#hooking-yourself-into-initialization)
So first let’s install the library:
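For example, with pip (use whatever package manager fits your environment):

pip install attrs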
Next, we can replace the previous imports from dataclasses with attrs.
And we can use __attrs_pre_init__ to call torch.nn.Module.__init__.
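Put together, the attrs version could look roughly like this (a sketch; the exact field declarations in the terminal may differ):

from attrs import define, field

from torch import nn
from torch.nn import Linear, MSELoss, ReLU, Sequential

@define
class NeuralNetwork(nn.Module):
    x_dim: int
    h_dim: int
    y_dim: int
    num_hidden_layers: int
    net: Sequential = field(init=False)
    loss: MSELoss = field(factory=MSELoss, init=False)

    def __attrs_pre_init__(self):
        # runs before the attrs-generated __init__, so nn.Module's
        # internal state exists before any attribute assignment
        super().__init__()

    def __attrs_post_init__(self):
        layers = [Linear(self.x_dim, self.h_dim), ReLU()]
        for _ in range(self.num_hidden_layers - 1):
            layers += [Linear(self.h_dim, self.h_dim), ReLU()]
        layers += [Linear(self.h_dim, self.y_dim), ReLU()]
        self.net = Sequential(*layers)

    def forward(self, x, target):
        return self.loss(self.net(x), target)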
Now this seems to work; there are no more visible errors in the class
NeuralNetwork(x_dim=4096, h_dim=42, y_dim=1, num_hidden_layers=3)
Hashenanigans
Let’s try to do the very first step of a usual training setup, which would be simply passing model.parameters() to an optimizer.
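For example, something like this (the optimizer and learning rate are arbitrary choices for illustration):

from torch.optim import SGD

optimizer = SGD(model.parameters(), lr=0.01)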
The error
TypeError: unhashable type: 'NeuralNetwork'
indicates that an nn.Module is expected to be hashable, which our NeuralNetwork currently is not.
The cause for this lies in how attrs works. It seems like somehow using @define makes the class unhashable.
There is a great explanation of why that happens in the documentation: https://www.attrs.org/en/stable/hashing.html
Basically, if there is a custom __eq__ implementation in our class (and by default attrs creates one), Python will make the class unhashable.
An easy fix is adding @define(eq=False), so that attrs doesn’t create an __eq__ for us (and in our case comparing models directly doesn’t really make sense anyway).
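So the decorator line becomes:

@define(eq=False)
class NeuralNetwork(nn.Module):
    ...  # everything else stays the same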
There’s no error anymore, and finally our hybridization is complete.
One last potential bug
You should not initialize an optional attribute with None.
This is a subtle bug, and it is not specific to our current setup; it could also occur in a usual custom PyTorch model. However, with the current setup there is a bigger chance that you would run into it.
If we first assign e.g. loss=None, this results in a super().__setattr__('loss', None) call from nn.Module.__setattr__, setting the value directly on the object (None is not an nn.Module, so it bypasses the _modules machinery).
If we then try to reassign another value to loss, the new value does get registered in the _modules dictionary internally, but when we later access model.loss, we will always get back None, because the attribute set directly on the object takes priority in Python over nn.Module.__getattr__ (which is only consulted when normal attribute lookup fails).
So don’t do this:
    ...
    loss: MSELoss = field(default=None)  # BAD: sets a plain None attribute first

    def __attrs_post_init__(self):
        ...
        self.loss = MSELoss()

if __name__ == '__main__':
    ...
    print(model.loss is None)
And even though we re-assigned self.loss, we would still get the output
...
True
Drawbacks of the hybrid class
Compared to the vanilla version of our model, the new pattern does introduce a bit of overhead, and you need to understand some of the internals of both nn.Module and dataclasses.
Benefits of this pattern
A nicer debugging experience
(Print-)debugging a model will now produce a dataclass-flavoured output, and you can also fine-tune this by using field(repr=False) to hide uninteresting parts of the model when printing.
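For example, something like this would keep the (potentially large) net out of the printed representation:

net: Sequential = field(init=False, repr=False)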
Division between passed arguments and object-managed arguments
This mostly goes for the intent of the code. Usually there is a clear division between arguments that are provided as input to the __init__ (but we might want to save these for later use, so they also become part of the object), and things that are created inside __post_init__ (these are completely “owned” by the object).
For complex models this could make the code a bit easier to understand.