Transfer Learning vs. Fine-tuning vs. Multitask Learning vs. Federated Learning
Four critical model training paradigms that you MUST know for real-world ML modelling.
Most ML models are trained independently without any interaction with other models.
However, in the realm of real-world ML, there are many powerful learning techniques that rely on model interactions to improve performance.
The following animation neatly summarizes four such well-adopted and must-know training methodologies:
Let’s discuss them today.
#1) Transfer learning
This is extremely useful when:
The task of interest has limited data.
But a related task has abundant data.
This is how it works:
Train a neural network model (base model) on the related task.
Replace the last few layers of the base model with new layers.
Train the network on the task of interest, but don’t update the weights of the unreplaced layers during backpropagation.
By training a model on the related task first, we capture core patterns that are also relevant to the task of interest.
Later, we can adjust the last few layers to capture task-specific behavior.
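For example, here is a minimal PyTorch sketch of those steps. The ResNet-18 base model and the 10-class target task are just placeholders (and the `weights=` API assumes a recent torchvision), not a prescription:

```python
import torch
import torch.nn as nn
from torchvision import models

# Base model pre-trained on a related, data-rich task (ImageNet here).
base_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze every existing layer so backpropagation leaves them untouched.
for param in base_model.parameters():
    param.requires_grad = False

# Replace the final layer with a new head for the task of interest
# (a hypothetical 10-class problem).
base_model.fc = nn.Linear(base_model.fc.in_features, 10)

# Only the new head's parameters go to the optimizer, so only they are updated.
optimizer = torch.optim.Adam(base_model.fc.parameters(), lr=1e-3)
```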
Another idea somewhat along these lines is knowledge distillation, which also involves a “transfer” of knowledge. We discussed it here if you are interested in learning about it.
Transfer learning is commonly used in many computer vision tasks.
#2) Fine-tuning
Fine-tuning involves updating the weights of some or all layers of the pre-trained model to adapt it to the new task.
The idea may appear similar to transfer learning, but in fine-tuning, we typically do not replace the last few layers of the pre-trained network.
Instead, the pre-trained model itself is adjusted to the new data.
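A rough PyTorch sketch of this, again with a placeholder ResNet-18 and a hypothetical dataloader (`new_task_loader`) for the new task:

```python
import torch
from torchvision import models

# Start from the pre-trained model; no layers are frozen or replaced.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# All parameters stay trainable. A small learning rate gently adjusts
# the pre-trained weights to the new data instead of overwriting them.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Standard training loop on the new data (loader assumed to exist):
# for inputs, labels in new_task_loader:
#     optimizer.zero_grad()
#     loss = torch.nn.functional.cross_entropy(model(inputs), labels)
#     loss.backward()
#     optimizer.step()
```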
#3) Multi-task learning
As the name suggests, a model is trained to perform multiple tasks simultaneously.
The model shares knowledge across tasks, aiming to improve generalization and performance on each task.
It helps in scenarios where the tasks are related or can benefit from shared representations.
In fact, the motivation for multi-task learning is not just better generalization.
We can also save compute during training by sharing layers across tasks and keeping only small task-specific segments.
Imagine training two models independently on related tasks.
Now compare that to a single network with shared layers followed by task-specific branches (sketched after the list below).
Option 2 will typically result in:
Better generalization across all tasks.
Less memory utilization to store model weights.
Less resource utilization during training.
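Here is a minimal PyTorch sketch of that second option, with made-up dimensions and two toy heads (a classification head and a regression head):

```python
import torch
import torch.nn as nn

# A minimal multi-task network: one shared trunk, two task-specific heads.
class MultiTaskNet(nn.Module):
    def __init__(self, in_dim=32, hidden_dim=64):
        super().__init__()
        # Shared layers capture representations useful for both tasks.
        self.shared = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
        )
        # Task-specific branches.
        self.head_a = nn.Linear(hidden_dim, 3)   # e.g., a 3-class task
        self.head_b = nn.Linear(hidden_dim, 1)   # e.g., a regression task

    def forward(self, x):
        h = self.shared(x)
        return self.head_a(h), self.head_b(h)

model = MultiTaskNet()
x = torch.randn(8, 32)                           # a dummy batch
out_a, out_b = model(x)

# A combined loss lets gradients from both tasks update the shared trunk.
loss = (
    nn.functional.cross_entropy(out_a, torch.randint(0, 3, (8,)))
    + nn.functional.mse_loss(out_b.squeeze(1), torch.randn(8))
)
loss.backward()
```

Both heads backpropagate through the same shared trunk, which is precisely where the memory and compute savings listed above come from.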
This and this are two of the best survey papers I have ever read on multi-task learning.
#4) Federated learning
This is another pretty cool technique for training ML models.
Simply put, federated learning is a decentralized approach to machine learning where the training data remains on users’ devices (e.g., smartphones).
Instead of sending data to a central server, models are sent to devices, trained locally, and only model updates are gathered and sent back to the server.
It is particularly useful for enhancing privacy and security. What’s more, it also reduces the need for centralized data collection.
Our smartphone’s keyboard is a great example of this.
Federated learning allows our smartphone’s keyboard to learn and adapt to our typing habits. This happens without transmitting sensitive keystrokes or personal data to a central server.
The model, which predicts our next word or suggests auto-corrections, is sent to our device, and the device itself fine-tunes the model based on our input.
Over time, the model becomes personalized to our typing style while preserving our data privacy and security.
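To make the flow concrete, here is a toy PyTorch sketch of one round of federated averaging (FedAvg). The tiny linear model, the two simulated clients, and their random data are stand-ins for illustration, not a production setup:

```python
import copy
import torch
import torch.nn as nn

def local_update(global_model, local_data, lr=0.01, epochs=1):
    """Train a copy of the global model on one client's local data."""
    model = copy.deepcopy(global_model)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(epochs):
        for x, y in local_data:
            optimizer.zero_grad()
            nn.functional.mse_loss(model(x), y).backward()
            optimizer.step()
    return model.state_dict()  # only weights leave the device, never the data

def federated_average(client_states):
    """Server-side step: average the clients' model weights."""
    avg = copy.deepcopy(client_states[0])
    for key in avg:
        for state in client_states[1:]:
            avg[key] += state[key]
        avg[key] /= len(client_states)
    return avg

# Toy setup: a tiny global model and two clients with private local data.
global_model = nn.Linear(4, 1)
clients = [
    [(torch.randn(8, 4), torch.randn(8, 1))],  # client 1's local batches
    [(torch.randn(8, 4), torch.randn(8, 1))],  # client 2's local batches
]

client_states = [local_update(global_model, data) for data in clients]
global_model.load_state_dict(federated_average(client_states))
```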
Do note that since these models are trained on small devices, they must be extremely lightweight yet powerful enough to be useful.
Model compression techniques are prevalent in such use cases, which we discussed in detail here.
Pretty cool, isn’t it?
👉 Over to you: What are some other ML training methodologies that I have missed here?
👉 If you liked this post, don’t forget to leave a like ❤️. It helps more people discover this newsletter on Substack and tells me that you appreciate reading these daily insights.
The button is located towards the bottom of this email.
Thanks for reading!
Latest full articles
If you’re not a full subscriber, here’s what you missed last month:
You Cannot Build Large Data Projects Until You Learn Data Version Control!
Why Bagging is So Ridiculously Effective At Variance Reduction?
Sklearn Models are Not Deployment Friendly! Supercharge Them With Tensor Computations.
Deploy, Version Control, and Manage ML Models Right From Your Jupyter Notebook with Modelbit
Gaussian Mixture Models (GMMs): The Flexible Twin of KMeans.
To receive all full articles and support the Daily Dose of Data Science, consider subscribing:
👉 Tell the world what makes this newsletter special for you by leaving a review here :)
👉 If you love reading this newsletter, feel free to share it with friends!