Almost all real-world ML models gradually degrade in performance due to a drift in feature distribution.
This is a serious problem because the model was trained on one distribution, but in production it generates predictions on another.
Thus, it is critical to detect drift early so that models continue to work well.
Today, I want to share one intuitive technique I often use to determine which features are drifting in my dataset.
Let’s begin!
Suppose we have two versions of the dataset: the old version (the one on which the model was trained) and the current version (the one on which the model is generating predictions).
The core idea is to assess if there are any distributional dissimilarities between these two versions.
So here’s what we do:
- Append a label=1 column to the old dataset.
- Append a label=0 column to the current dataset.
Now, merge these two datasets and train a supervised classification model on the combined dataset.
The choice of classification model is somewhat arbitrary, but it should let us determine feature importance reliably.
Thus, I personally prefer a random forest classifier because it has an inherent mechanism to determine feature importance.
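To make the steps above concrete, here is a minimal sketch using pandas and scikit-learn. The data is synthetic and the column names (`amount`, `age`), sample sizes, and hyperparameters are assumptions made purely for illustration: `amount` is simulated to drift, while `age` is not.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)

# Hypothetical data: "amount" drifts between versions, "age" does not.
old_df = pd.DataFrame({
    "amount": rng.normal(100, 20, 2000),
    "age": rng.normal(40, 10, 2000),
})
current_df = pd.DataFrame({
    "amount": rng.normal(130, 20, 2000),  # mean shifted: simulated drift
    "age": rng.normal(40, 10, 2000),      # same distribution as the old version
})

# Step 1: append label=1 to the old version and label=0 to the current one.
old_df["label"] = 1
current_df["label"] = 0

# Step 2: merge the two versions into a single dataset.
combined = pd.concat([old_df, current_df], ignore_index=True)
X = combined.drop(columns="label")
y = combined["label"]

# Step 3: train a classifier to tell the two versions apart.
# Shallow trees limit how much importance splits on pure noise can pick up.
clf = RandomForestClassifier(n_estimators=200, max_depth=5, random_state=0)
clf.fit(X, y)

# Step 4: inspect the impurity-based feature importances (they sum to 1).
importances = pd.Series(clf.feature_importances_, index=X.columns)
print(importances.sort_values(ascending=False))
```

In this toy setup, the drifting `amount` column should come out well ahead of `age`. On real data, the ranking is only meaningful if the classifier can actually separate the two versions.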
That said, it is not necessary to use a random forest.
Techniques like shuffle (permutation) feature importance, which we have discussed before, can be applied on top of any classification model as well.
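As a rough sketch of that alternative, the snippet below applies scikit-learn's `permutation_importance` to a logistic regression pipeline. The synthetic data, the column names, the choice of logistic regression, and the held-out split are all assumptions for illustration, not something prescribed by the technique.

```python
import numpy as np
import pandas as pd
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n = 2000

# Hypothetical combined dataset: label=1 is the old version, label=0 the current one.
X = pd.DataFrame({
    "amount": np.concatenate([rng.normal(100, 20, n), rng.normal(130, 20, n)]),  # drifts
    "age": np.concatenate([rng.normal(40, 10, n), rng.normal(40, 10, n)]),       # stable
})
y = np.concatenate([np.ones(n, dtype=int), np.zeros(n, dtype=int)])

# Hold out part of the data so importance is measured on unseen rows.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

# Any classifier works here; this one has no built-in importance scores.
model = make_pipeline(StandardScaler(), LogisticRegression())
model.fit(X_train, y_train)

# Shuffle one column at a time and record the average drop in accuracy.
result = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=0)
perm_importances = pd.Series(result.importances_mean, index=X.columns)
print(perm_importances.sort_values(ascending=False))
```

A feature whose shuffling barely changes the score contributes little to separating the two versions, so it is unlikely to be drifting in a way this model can see.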
Moving on…
If the classifier separates the two versions better than chance and some features receive high importance values, this suggests that those features are drifting.
Why?
This is because if some features can reliably distinguish between the two versions of the dataset, then it is quite likely that their distributions conditioned on label=1 and label=0 differ.
If there are distributional differences, the model will capture them.
If there are no distributional differences, the model will struggle to distinguish between the classes.
This idea makes intuitive sense as well.
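One way to see this in practice is to check how well the classifier separates the two versions at all. The sketch below is an assumption on my part rather than part of the technique as described: it scores the classifier with cross-validated ROC AUC, where a value near 0.5 means the model cannot tell the versions apart. The helper name `domain_auc` and the synthetic data are hypothetical.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score


def domain_auc(old, current, seed=0):
    """Cross-validated ROC AUC of a classifier separating old rows from current rows."""
    X = np.vstack([old, current])
    y = np.concatenate([np.ones(len(old), dtype=int), np.zeros(len(current), dtype=int)])
    clf = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=seed)
    return cross_val_score(clf, X, y, cv=5, scoring="roc_auc").mean()


rng = np.random.default_rng(0)
old = rng.normal(0, 1, size=(1000, 3))

# Case 1: the current version comes from the same distribution (no drift).
current_same = rng.normal(0, 1, size=(1000, 3))

# Case 2: the first feature of the current version is shifted (simulated drift).
current_shifted = rng.normal(0, 1, size=(1000, 3))
current_shifted[:, 0] += 1.5

auc_no_drift = domain_auc(old, current_same)
auc_drift = domain_auc(old, current_shifted)
print(f"AUC without drift: {auc_no_drift:.3f}")
print(f"AUC with drift:    {auc_drift:.3f}")
```

When the AUC stays close to 0.5, the feature importances are mostly noise and should not be read as evidence of drift.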
At this point, one question that many people ask is:
Why can’t we just monitor the model accuracy to determine drift?
Of course, we can do that as long as we have the true labels for the current version or have a way to compare the model performance on the old version and the new version.
But in many cases, the true labels for production data are not immediately available.
Instead, they arrive after some delay.
For instance, when I was working on a transactional fraud detection model at Mastercard, a cardholder’s issuer bank may take up to 45-50 days to send the fraud label for the transactions that went through Mastercard’s network.
This is a lot of time, isn’t it?
Thus, one must rely on some relevant feedback from the system to determine whether the model’s performance is dropping or not.
The “proxy-labeling technique” discussed above is something I have found to be immensely useful for this.
Of course, this is not the only technique to determine drift.
I once published this three-part guide, which I am sure you will find useful to read next:
Part 1: Covariate Shift Is Way More Problematic Than Most People Think
Part 2: How to Detect Multivariate Covariate Shift in Machine Learning Models?
Part 3: How to Interpret Reconstruction Loss While Detecting Multivariate Covariate Shift?
Also, if you have ever been intimidated by deployment, this is a complete beginner-friendly guide on deploying, version controlling, and managing ML models from your Jupyter Notebook with Modelbit.
👉 Over to you: What are some other ways you use to determine drift?
Thanks for reading!
Great share, Avi bro. I think we can do the same using statistical methods like the Kolmogorov-Smirnov test, which helps us detect whether two distributions are the same or not.
Interesting technique. I have two questions:
- How do you quantitatively decide which features are high-importance and which are low-importance after training the random forest? Do you use a threshold, or do you cluster the values?
- Given a real-time ML system, at what frequency do you use proxy-labeling techniques? Curious to know if you have thought about an architecture for this.