An Underrated Technique to Define More Elegant Python Classes
Ever wondered why we never explicitly invoke the forward() method in PyTorch?
A Python class typically defines methods that its objects can invoke.
For instance, suppose we want to evaluate the following quadratic:
One way is to define a method that accepts the input and returns the value of the quadratic, as shown below:
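The original quadratic and code are not reproduced here. As a stand-in, here is a minimal sketch that assumes a generic quadratic f(x) = ax² + bx + c. The class name `Quadratic` and the sample coefficients are hypothetical. The `evaluate()` method name comes from the post.

```python
class Quadratic:
    """Evaluates f(x) = a*x**2 + b*x + c (hypothetical stand-in example)."""

    def __init__(self, a, b, c):
        # Store the coefficients on the instance
        self.a, self.b, self.c = a, b, c

    def evaluate(self, x):
        # An explicit method: callers must write obj.evaluate(x)
        return self.a * x**2 + self.b * x + self.c


f = Quadratic(1, 2, 3)
print(f.evaluate(2))  # 11
```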
Of course, there is nothing wrong with this approach.
But Python offers a smarter, more elegant way of doing this.
Instead of explicitly invoking a method, we can define the __call__() magic method. This magic method defines what happens when an instance of the class is invoked like a function (like this: object()).
Let’s rename the evaluate() method to __call__().
As a result, we can now invoke the object directly instead of calling a method explicitly.
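Continuing the hypothetical `Quadratic` stand-in (generic ax² + bx + c, illustrative coefficients), a sketch of the renamed version might look like this. Only the method name changes, but the object itself now becomes callable.

```python
class Quadratic:
    """Evaluates f(x) = a*x**2 + b*x + c (hypothetical stand-in example)."""

    def __init__(self, a, b, c):
        self.a, self.b, self.c = a, b, c

    def __call__(self, x):
        # Python runs this when the instance is invoked like a function: f(x)
        return self.a * x**2 + self.b * x + self.c


f = Quadratic(1, 2, 3)
print(f(2))         # 11, no explicit method name needed
print(callable(f))  # True
```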
This can have many advantages. For instance:
It allows us to implement objects that can be used in a flexible and intuitive way.
It allows us to use a class object in contexts where a callable object is expected — using a class as a decorator, for instance.
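Both advantages can be sketched in a few lines. In this example, `Scale` and `CountCalls` are hypothetical names. A `Scale` instance is passed to `map()`, which expects a callable. `CountCalls` is a class-based decorator: the decorated name becomes an instance whose `__call__()` forwards to the original function.

```python
import functools


class Scale:
    """A configurable callable: multiplies its input by a fixed factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return self.factor * x


# Any object with __call__ works where a plain function is expected
print(list(map(Scale(10), [1, 2, 3])))  # [10, 20, 30]


class CountCalls:
    """Class-based decorator that counts how often a function is called."""

    def __init__(self, func):
        # Copy __name__, __doc__, etc. from the wrapped function
        functools.update_wrapper(self, func)
        self.func = func
        self.calls = 0

    def __call__(self, *args, **kwargs):
        # Runs every time the decorated function is invoked
        self.calls += 1
        return self.func(*args, **kwargs)


@CountCalls
def greet(name):
    return f"Hello, {name}!"


print(greet("Ada"))   # Hello, Ada!
print(greet("Alan"))  # Hello, Alan!
print(greet.calls)    # 2
```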
In fact, unknown to many, this happens all the time when we build deep learning models with PyTorch.
For instance, consider this simple PyTorch class:
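The original class is not reproduced here. As a stand-in, here is a minimal two-layer model. The name `MyModel` and the layer sizes are hypothetical. Subclassing `torch.nn.Module` and defining `forward()` is the standard PyTorch pattern.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        # Defines the forward pass: linear -> ReLU -> linear
        return self.fc2(torch.relu(self.fc1(x)))


model = MyModel()
x = torch.randn(3, 4)  # a batch of 3 samples with 4 features each
out = model(x)         # note: forward() is never named explicitly
print(out.shape)       # torch.Size([3, 2])
```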
Here, the forward() method defines the forward pass of the model.
Now tell me something.
When was the last time you explicitly invoked the model.forward() method to run the forward pass?
Chances are, you never have.
Instead, PyTorch users almost always just write model() to run the forward pass, as if the model object were a Python function.
But the variable model is an object (an instance of our model class), right? It is not a function. This can be verified below:
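A minimal check, assuming the same hypothetical two-layer `MyModel` as a stand-in for the post's class. The checks use standard tools: `types.FunctionType`, `isinstance()`, and `callable()`.

```python
import types

import torch
import torch.nn as nn


class MyModel(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = MyModel()
print(type(model).__name__)                   # MyModel
print(isinstance(model, types.FunctionType))  # False: not a function
print(isinstance(model, nn.Module))           # True: an nn.Module instance
print(callable(model))                        # True: yet it can be called
```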
Then how are we able to invoke it like a function, as in model()?
As you may have already guessed, this is possible because torch.nn.Module, the base class every PyTorch model inherits from, defines the __call__() method.
Within that __call__() method, PyTorch invokes the user-defined forward() method.
A simplified version of this is depicted below:
PyTorch itself adds the __call__() method. The __call__() method invokes the user-defined forward() method.
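Here is a pure-Python sketch of the same idea. `Module` and `Doubler` are hypothetical, heavily simplified stand-ins and are not PyTorch's actual code. The real `nn.Module.__call__()` does more work, such as running registered hooks. The core mechanism is the same: the base class owns `__call__()`, and `__call__()` delegates to the subclass's `forward()`.

```python
class Module:
    """Hypothetical, heavily simplified stand-in for torch.nn.Module."""

    def __call__(self, *args, **kwargs):
        # The base class owns __call__; the real one also runs hooks here
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # Subclasses are expected to override this
        raise NotImplementedError("Subclasses must implement forward()")


class Doubler(Module):
    def forward(self, x):
        # The "user-defined" forward pass
        return 2 * x


model = Doubler()
print(model(21))          # 42, routed through Module.__call__
print(model.forward(21))  # 42, same result
```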
This way, Python knows that the model object can be invoked like a function: model().
In fact, we can verify that we will get the same output no matter which way we run the forward pass:
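A sketch of that check, again assuming the hypothetical `MyModel` stand-in. Outputs are compared with `torch.allclose()`. One caveat, documented by PyTorch: `model(x)` also runs any registered forward hooks, but `model.forward(x)` silently skips them. That is why calling the module itself is the recommended style. This model has no hooks, so both calls agree.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


torch.manual_seed(0)
model = MyModel()
x = torch.randn(3, 4)

with torch.no_grad():
    out_call = model(x)             # goes through nn.Module.__call__
    out_forward = model.forward(x)  # calls forward() directly, no hooks

print(torch.allclose(out_call, out_forward))  # True
```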
Cool Pythonistic stuff, isn’t it?
These things revolve around good and elegant object-oriented programming practices.
We covered such advanced OOP stuff in detail in a recent deep dive here if you wish to level up your OOP skills: Object-Oriented Programming with Python for Data Scientists.
👉 Over to you: What are some other cool Python OOP tricks?
Thanks for reading!