Note:
This is an interactive tutorial; scroll down to begin
#tooling #python #pytorch #language-features #machine-learning

Posted: Nov 18, 2024

A hybrid design: using nn.Modules with @dataclasses

About this tutorial

Python’s dataclasses and PyTorch’s nn.Modules are rarely used together.

The communities using these features may simply not overlap much, but I suspect the main reason is that combining them isn’t quite straightforward.

In this tutorial, we’ll explore some of the design and inner workings of nn.Module, and some rules about Python classes and dataclasses. Along the way, we’ll encounter a few pitfalls — and learn how to overcome them.

By the end of this tutorial, we’ll have implemented a small hybrid class that integrates both.

The format of this tutorial

scrolling

This post is also an experiment on a new tutorial format.

As you scroll, you will see code being written live in a virtual terminal on the right (or at the bottom if you are on mobile).

Scrolling works a bit differently on this site. You can scroll within this text or on the right side of the virtual terminal (sorry, lefties).

typing and navigating

If you like typing tests, and want to practice a bit, you can follow along with the tutorial by typing exactly as shown on the key press overlay below the terminal. The content will automatically scroll with your progress.

You can also click or tap on a letter on the overlay to jump to that specific point in time.

Pressing the left arrow key will take you one step back.

tmux

The virtual terminal is running tmux for multiplexing, so you will see some key-combinations that switch between the code editor and the terminal used for running code.

neovim

Neovim is used for code editing, and its normal mode commands frequently appear throughout the tutorial.

Understanding these could be challenging if you’re not already familiar with Vim motions. Feel free to skip them or try to figure out how they work and learn a bit of Vim in practice. However, this tutorial isn’t focused on teaching Vim motions — consider this your warning.

Some neovim plugins will also appear at certain points; these have some unique keybindings too.

A Point in 2D

Let’s look at a typical example. In modern Python, we can create a simple class to represent a 2D point.
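A sketch of the class being written in the terminal (the exact code there may differ slightly):

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y


if __name__ == "__main__":
    p = Point(3, 4)
    print(p)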

Let’s run this code

The output we get is the representation of our point object.

<__main__.Point object at 0x716bd3f19e90>

This is a bit opaque — it only tells us the type and the memory address of the object.

__repr__ for transparency

Often, we’re more interested in the contents of an object rather than its identity.

To get a more human-friendly representation, we can override the __repr__ method of the class. This will make the output clearer and more informative, making our lives a bit easier when debugging.
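Roughly, the method added inside Point could look like this:

    def __repr__(self):
        return f"Point(x={self.x}, y={self.y})"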

The new representation of a point will look like this:

Point(x=3, y=4)

equality

For fun, we can check if two points are equal

p = Point(3, 4)
q = Point(5, 12)
(p == q) is False

This seems intuitive so far, but what if we compare two point objects that we expect to be equal because they have the same coordinates?

(Point(3,4) == Point(3,4)) is False

By default, equality in Python is based on the identity of the objects, so a point will only be equal to itself:

(p == p) is True

To override this behaviour, let’s create a custom __eq__ method that compares the coordinates:
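A minimal version might be (the exact checks in the terminal code may differ):

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return (self.x, self.y) == (other.x, other.y)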

(Point(3,4) == Point(3,4)) is True

The output is now more in line with what we would expect when working with Points.

Point(x=3, y=4) Point(x=5, y=12) Point(x=3, y=4)
True False True

Eliminating boilerplate with @dataclass

A fairly simple example so far, yet we have written quite a bit of code already, only to achieve what could be considered a reasonable default.

This is where dataclasses come into play. We can get the exact same functionality with far fewer lines of code.

The @dataclass decorator automatically generates the __init__, __repr__, and __eq__ methods for the class, making our code much more concise while maintaining the same functionality.
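The dataclass version of Point boils down to something like this (the int annotations are my assumption; float would work just as well):

from dataclasses import dataclass


@dataclass
class Point:
    x: int
    y: int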

By running the code again, we get the same output as before

Point(x=3, y=4) Point(x=5, y=12) Point(x=3, y=4)
True False True

This simplification also clarifies the intent of the code. By not explicitly defining the __init__, __repr__, and __eq__ methods, we’re expressing that we want the default, most straightforward implementation of those.

Adding more methods works as usual. For example, a p-norm could look like this:
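(The method name and the p=2 default below are my guesses; only the 2-norm values appear in the output.)

    def norm(self, p: int = 2) -> float:
        return (abs(self.x) ** p + abs(self.y) ** p) ** (1 / p)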

output

The Pythagorean triples are now complete with this output

...
5.0
13.0

Weightlifting with nn.Modules

A simple model

Now let’s look at a simple custom model definition using PyTorch.

For a usual model we need 4 core elements (a sketch follows the list):

  • subclassing nn.Module
  • calling super().__init__ at the beginning of __init__
  • defining the building blocks of the model (usually other nn.Module subclasses)
  • implementing a forward method (this is where we define how the data flows and transforms through our model)
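Putting those four elements together for the SingleLayer model whose printed representation appears a bit further down, the code might look roughly like this (the layer sizes come from that output; the forward body is my guess):

from torch import nn


class SingleLayer(nn.Module):
    def __init__(self):
        # super().__init__ first: sets up nn.Module's internal state
        super().__init__()
        # building blocks of the model
        self.linear = nn.Linear(1764, 512)
        self.relu = nn.ReLU()

    def forward(self, x):
        # how the data flows through the model
        return self.relu(self.linear(x))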

about init

The __init__ looks simple, but there is in fact a lot happening behind the scenes. The super().__init__() call is crucial here: it sets up the internal mechanisms of nn.Module.

From the PyTorch documentation:

an __init__() call to the parent class must be made before assignment on the child.

So in our example, super().__init__() has to happen before self.linear = ...

We now have but one choice: We must face the long dark of PyTorch

nn.Module has its own __setattr__ implementation, so the self.linear = ... and self.relu = ... assignments aren’t as simple as they appear.

To understand how the internal state management works, we can take a look at the source code of nn.Module.__init__.

The comment here captures the main idea:

"""
Calls `super().__setattr__('a', a)` instead of the typical `self.a = a`
to avoid `Module.__setattr__` overhead. Module's `__setattr__` has special
handling for `parameter`s, `submodule`s, and `buffer`s but simply calls into
`super().__setattr__` for all other attributes.
"""

internal dictionaries for modules

To oversimplify a bit: if we are assigning an nn.Module as an attribute like this:

self.linear = nn.Linear(...)

the special handling in __setattr__ will see that the value we are assigning has the type nn.Linear (which is a subclass of nn.Module), and will store it in the _modules dictionary:

self._modules['linear'] = nn.Linear(...)

Parameters and buffers are handled similarly.

This internal state management enables more ergonomic code. For example, we can just call .to on a Module, and it will move all of its parameters (including the parameters of its submodules) to the given device.
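A quick way to see this, poking at internals and reusing the SingleLayer sketch from above:

model = SingleLayer()
print(list(model._modules.keys()))  # ['linear', 'relu']
model.to("cpu")  # moves all registered parameters and buffers to the given device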

__repr__

nn.Module has its own __repr__ implementation that makes use of the _modules dictionary to include the building blocks of the model in its printed representation:

SingleLayer(
  (linear): Linear(in_features=1764, out_features=512, bias=True)
  (relu): ReLU()
)

A more flexible NeuralNetwork

Let’s create another NeuralNetwork, now with a few hyperparameters: we create parameters for the input-, hidden- and output-layer dimensions, as well as the number of hidden layers.

And we can create the actual building blocks of the network: here they are Linear layers with ReLU activations.

We pass these to Sequential to get our net, and we also define the loss to be the Mean Squared Error.

The forward is really simple: we just pass the input x to net and calculate the loss, i.e. the mean squared error between the prediction y and the target.

In the __main__, we simply instantiate the model, and print it.
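Spelled out, the whole thing might look roughly like this (the exact layer construction and the forward signature are my reading of the description; the printed model below is what constrains them):

from torch import nn
from torch.nn import Linear, MSELoss, ReLU, Sequential


class NeuralNetwork(nn.Module):
    def __init__(self, x_dim, h_dim, y_dim, num_hidden_layers):
        super().__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.y_dim = y_dim
        self.num_hidden_layers = num_hidden_layers

        # Linear layers with ReLU activations
        layers = [Linear(x_dim, h_dim), ReLU()]
        for _ in range(num_hidden_layers - 1):
            layers += [Linear(h_dim, h_dim), ReLU()]
        layers += [Linear(h_dim, y_dim), ReLU()]

        self.net = Sequential(*layers)
        self.loss = MSELoss()

    def forward(self, x, target):
        y = self.net(x)
        return self.loss(y, target)


if __name__ == "__main__":
    model = NeuralNetwork(x_dim=4096, h_dim=42, y_dim=1, num_hidden_layers=3)
    print(model)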

Let’s see the output of this:

NeuralNetwork(
  (net): Sequential(
    (0): Linear(in_features=4096, out_features=42, bias=True)
    (1): ReLU()
    (2): Linear(in_features=42, out_features=42, bias=True)
    (3): ReLU()
    (4): Linear(in_features=42, out_features=42, bias=True)
    (5): ReLU()
    (6): Linear(in_features=42, out_features=1, bias=True)
    (7): ReLU()
  )
  (loss): MSELoss()
)

first attempt at @dataclass

There’s quite a bit of boilerplate in this class; we can again try to eliminate some of it with @dataclass.

To keep the generated __init__ intact, we can create net and loss in the __post_init__ method. __post_init__ is automatically called at the end of the @dataclass generated __init__.

Next, we can also define the attributes created in __post_init__ as fields of the class. By marking these with init=False, we exclude them from the generated __init__ and become responsible for initializing them elsewhere, such as in __post_init__.
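The first attempt might then look roughly like this. One assumption on my part: the owned fields also get repr=False, so that the printed representation matches the output shown later.

from dataclasses import dataclass, field

from torch import nn
from torch.nn import Linear, MSELoss, ReLU, Sequential


@dataclass
class NeuralNetwork(nn.Module):
    x_dim: int
    h_dim: int
    y_dim: int
    num_hidden_layers: int

    # created by us in __post_init__, excluded from the generated __init__
    net: Sequential = field(init=False, repr=False)
    loss: MSELoss = field(init=False, repr=False)

    def __post_init__(self):
        # note: there is no super().__init__() call anywhere yet
        layers = [Linear(self.x_dim, self.h_dim), ReLU()]
        for _ in range(self.num_hidden_layers - 1):
            layers += [Linear(self.h_dim, self.h_dim), ReLU()]
        layers += [Linear(self.h_dim, self.y_dim), ReLU()]
        self.net = Sequential(*layers)
        self.loss = MSELoss()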

The code is now a bit less repetitive, and the intent is clearer. As an added benefit, we get a separation between the attributes that are passed to the class and the ones that are created during its initialization: net and loss can be thought of as owned by the NeuralNetwork.

Not so superb init

The code in its current state suffers from a few errors, so let’s try to fix the first one

AttributeError: cannot assign module before Module.__init__() call

We forgot to call super().__init__. This has to happen somewhere before assigning a submodule like self.net = Sequential(*layers).

Naively, we could try to place this in the __post_init__ too:

    def __post_init__(self):
        super().__init__()

This seems to work:

NeuralNetwork(x_dim=4096, h_dim=42, y_dim=1, num_hidden_layers=3)

But we have to keep in mind that super().__init__ is supposed to be called at the beginning of initialization, otherwise we might run into more issues.

For example, let’s try a small refactor: since MSELoss doesn’t depend on any of the input arguments of __init__, we could pass a default_factory to its field.
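Something like this, with the corresponding assignment removed from __post_init__:

    loss: MSELoss = field(default_factory=MSELoss, repr=False)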

And now we get the same error again

AttributeError: cannot assign module before Module.__init__() call

We need a super().__init__ call before assigning PyTorch-specific values like an instance of MSELoss (which itself is a subclass of nn.Module).

But the default_factory is called inside the generated __init__, where there is an attempt to store the result as an attribute, like so:

self.loss = MSELoss()

super().__init__ is only called after this, since we placed it in the __post_init__ method.

Ideally, super().__init__() should happen at the very beginning of our __init__, but we would also like to rely on the @dataclass generated __init__, so we don’t have to write the tedious self.x = x lines.

Unfortunately, dataclasses offer no corresponding __pre_init__ hook that would be called before __init__. It is possible to hack something together with __new__ (since that runs before __init__), but let’s instead look at another library to help us.

switching libraries: attrs

The attrs library has the same simplifying power as the built-in dataclasses, and in some cases it’s even more powerful (https://www.attrs.org/en/stable/why.html#why-not)

One of those cases is if we want to hook into initialization (https://www.attrs.org/en/stable/init.html#hooking-yourself-into-initialization)

So first let’s install the library:
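In the tutorial this happens in the terminal; with pip, for example:

pip install attrs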

Next, we can replace the previous imports from dataclasses with attrs.

And we can use __attrs_pre_init__ to call torch.nn.Module.__init__.
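Put together, the attrs version might look roughly like this (as before, repr=False on the owned fields is my assumption, made so the repr matches the printed output):

from attrs import define, field

from torch import nn
from torch.nn import Linear, MSELoss, ReLU, Sequential


@define
class NeuralNetwork(nn.Module):
    x_dim: int
    h_dim: int
    y_dim: int
    num_hidden_layers: int

    net: Sequential = field(init=False, repr=False)
    loss: MSELoss = field(factory=MSELoss, init=False, repr=False)

    def __attrs_pre_init__(self):
        # runs before the body of the attrs-generated __init__
        nn.Module.__init__(self)

    def __attrs_post_init__(self):
        layers = [Linear(self.x_dim, self.h_dim), ReLU()]
        for _ in range(self.num_hidden_layers - 1):
            layers += [Linear(self.h_dim, self.h_dim), ReLU()]
        layers += [Linear(self.h_dim, self.y_dim), ReLU()]
        self.net = Sequential(*layers)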

Now this seems to work; there are no more visible errors in the class:

NeuralNetwork(x_dim=4096, h_dim=42, y_dim=1, num_hidden_layers=3)

Hashenanigans

Let’s try to do the very first step of a usual training setup, which would be simply passing model.parameters() to an optimizer.
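For example (the choice of optimizer and learning rate here is arbitrary, just enough to trigger the call):

from torch.optim import SGD

optimizer = SGD(model.parameters(), lr=1e-3)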

The error

TypeError: unhashable type: 'NeuralNetwork'

indicates that our model is expected to be hashable, which it currently is not.

The cause lies in how attrs works: using @define makes the class unhashable.

There is a great explanation of why that happens in the documentation (https://www.attrs.org/en/stable/hashing.html)

Basically, if a class has a custom __eq__ implementation (and by default attrs creates one) without a matching __hash__, Python makes it unhashable by setting __hash__ to None.

An easy fix is adding @define(eq=False), so that attrs doesn’t create an __eq__ for us (and in our case, comparing models directly doesn’t really make sense anyway).
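The only change is on the decorator line:

@define(eq=False)  # keep nn.Module's identity-based equality and hashing
class NeuralNetwork(nn.Module):
    ...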

There’s no error anymore, and finally our hybridization is complete.

One last potential bug

You should not initialize an optional attribute with None.

This is a subtle bug, and it’s not specific to our current setup: it could also occur in a regular custom PyTorch model. However, with the current setup, you are more likely to run into it.

If we first assign e.g. loss=None, nn.Module.__setattr__ falls through to super().__setattr__('loss', None) (None is not a module, parameter, or buffer), so loss gets stored as a regular attribute, outside of the _modules machinery.

If we then reassign another value to loss, the new value does get registered in the _modules dictionary internally, but accessing model.loss later will always give us back None: the regular attribute takes priority, because nn.Module.__getattr__ (which looks the name up in _modules) is only invoked when normal attribute lookup fails.

So don’t do this:

...
    loss: MSELoss = field(default=None)

    def __attrs_post_init__(self):
        ...
        self.loss = MSELoss()

if __name__ == '__main__':
    ...
    print(model.loss is None)

And even though we re-assigned self.loss, we would still get the output

...
True

Drawbacks of the hybrid class

Compared to the vanilla version of our model, the new pattern does introduce a bit of overhead, and you need to understand some of the internals of both nn.Module and dataclasses/attrs.

Benefits of this pattern

A nicer debugging experience

(Print-)debugging a model will now produce a dataclass-flavoured output, and you can fine-tune it more easily, for example by using field(repr=False) to hide uninteresting parts of the model when printing.

Division between passed arguments and object-managed attributes

This mostly concerns the intent of the code. Usually there is a clear division between arguments that are provided as input to __init__ (we might want to save these for later use, so they also become part of the object) and things that are created inside __post_init__ (these are completely “owned” by the object).

For complex models this could make the code a bit easier to understand.

End of Tutorial