PyTorch has always been my go-to library for building any deep learning model.
However, one thing I particularly dislike about PyTorch is manually writing its long training loops, which go as follows:
For every epoch:
    For every batch:
        Run the forward pass
        Calculate the loss
        Run backpropagation to compute the gradients
        Update the model's weights
    Compute the epoch's accuracy
    Print the accuracy, loss, etc.
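Concretely, the loop above usually looks something like this (a minimal sketch; the model, data, and hyperparameters are placeholders for illustration):

```python
import torch
import torch.nn as nn

# Placeholder model, optimizer, and loss for the sketch.
model = nn.Linear(20, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Dummy dataset: 4 batches of 32 samples with 20 features each.
batches = [(torch.randn(32, 20), torch.randint(0, 2, (32,))) for _ in range(4)]

for epoch in range(3):                      # for every epoch
    correct, total, running_loss = 0, 0, 0.0
    for X_batch, y_batch in batches:        # for every batch
        optimizer.zero_grad()
        logits = model(X_batch)             # run the forward pass
        loss = criterion(logits, y_batch)   # calculate the loss
        loss.backward()                     # backprop: compute gradients
        optimizer.step()                    # update the weights
        running_loss += loss.item()
        correct += (logits.argmax(dim=1) == y_batch).sum().item()
        total += y_batch.size(0)
    acc = correct / total                   # compute epoch accuracy
    print(f"epoch {epoch}: loss={running_loss / len(batches):.4f}, acc={acc:.2f}")
```

And this is just the bare minimum; real loops also handle validation, device placement, logging, and more.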
That’s too much work and code, isn’t it?
Skorch immensely simplifies training neural networks with PyTorch.
Skorch (Sklearn + PyTorch) is an open-source library that provides full Scikit-learn compatibility to PyTorch.
This means we can train PyTorch models in a way similar to Scikit-learn, using functions such as fit(), predict(), score(), etc.
Isn’t that cool?
Let’s see how to use it!
First, we define our PyTorch neural network as we usually would (no change here):
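For instance, a small feed-forward classifier might look like this (the architecture, with 20 input features and two classes, is purely illustrative):

```python
import torch.nn as nn

# An illustrative PyTorch model; the architecture is arbitrary.
class MyClassifier(nn.Module):
    def __init__(self, num_features=20, num_classes=2):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(num_features, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, X):
        return self.layers(X)
```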
Make sure you have installed Skorch: pip install skorch.
As we are creating a classifier, we import Skorch's NeuralNetClassifier class and create an object of it. There's a class for regression models as well: NeuralNetRegressor.
The first argument is the PyTorch model class (MyClassifier). Next, we specify training hyperparameters like the learning rate, batch size, and number of epochs. We also pass the optimizer and the loss function as parameters.
Done!
Now, we can directly invoke the fit() method to train the model as follows:
As shown above, Skorch automatically prints all training metrics for us.
What’s more, we can also call the predict() and score() methods to generate predictions and compute the accuracy, respectively.
Isn’t that simple, cool, and elegant?
PyTorch Lightning is yet another library that supercharges the whole PyTorch framework. It comes with built-in, plug-and-play support for mixed-precision training, multi-GPU and TPU training, logging, profiling, and more, while drastically reducing boilerplate code.
We discussed it here recently: A Detailed and Beginner-Friendly Introduction to PyTorch Lightning: The Supercharged PyTorch.
👉 Over to you: Are you aware of any other utility libraries to simplify model training? Let me know :)
Whenever you are ready, here’s one more way I can help you:
Every week, I publish 1-2 in-depth deep dives (typically 20+ mins long). Here are some of the latest ones that you will surely like:
[FREE] A Beginner-friendly and Comprehensive Deep Dive on Vector Databases.
You Are Probably Building Inconsistent Classification Models Without Even Realizing
Why Sklearn’s Logistic Regression Has no Learning Rate Hyperparameter?
PyTorch Models Are Not Deployment-Friendly! Supercharge Them With TorchScript.
Federated Learning: A Critical Step Towards Privacy-Preserving Machine Learning.
You Cannot Build Large Data Projects Until You Learn Data Version Control!
To receive all full articles and support the Daily Dose of Data Science, consider subscribing:
👉 If you love reading this newsletter, feel free to share it with friends!
👉 Tell the world what makes this newsletter special for you by leaving a review here :)