Why Accuracy Is Not A Good Metric For Imbalanced Data

Last Updated on August 11, 2022 by Editorial Team

Author(s): Rafay Qayyum

Originally published on Towards AI.

Classification, in machine learning, is a supervised learning task in which data points are assigned to one of several classes. Examples include determining whether an email is “spam” or “not spam” and determining the blood type of a patient.

Machine Learning Classification is generally divided into three categories:

  • Binary Classification
  • Multi-class Classification
  • Multi-label Classification
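
As a rough illustration of how the three categories differ, the sketch below builds one small label array per category and asks scikit-learn to name its type. The label values are made up for this example and are not from the article's dataset.

```python
import numpy as np
from sklearn.utils.multiclass import type_of_target

# Hypothetical label arrays, one per category (values are made up for illustration).
y_binary = np.array(["spam", "not spam", "spam", "not spam"])  # exactly two classes
y_multiclass = np.array(["A", "B", "AB", "O", "A"])            # one of several classes
y_multilabel = np.array([[1, 0, 1],                            # each row may switch on
                         [0, 1, 0],                            # any number of labels
                         [1, 1, 0]])

# type_of_target reports how scikit-learn interprets each target array.
kinds = [type_of_target(t) for t in (y_binary, y_multiclass, y_multilabel)]
print(kinds)
```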

What are Imbalanced classes or data?

Imbalanced data refers to a problem where the distribution of examples across the known classes is skewed, so that one class has far more instances than another. For example, one class may have 10000 instances while the other has just 100.

Class with majority instances is weighed more than the class with minority instances — Google

Data imbalance can range from small to huge differences in the number of instances per class. Small imbalances such as 4:1 or 10:1 usually won’t harm your model much, but as the imbalance grows to 1000:1 or 5000:1, it can create problems for your machine learning model.
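
To put a number on the imbalance, one simple option is to divide the majority count by the minority count. The helper name `imbalance_ratio` and the label list below are assumptions made for this sketch, using the 10000-versus-100 example from earlier.

```python
from collections import Counter

def imbalance_ratio(labels):
    """Return (majority_count / minority_count, per-class counts) for a label sequence."""
    counts = Counter(labels)
    majority = max(counts.values())
    minority = min(counts.values())
    return majority / minority, counts

# Hypothetical label list matching the 10000-vs-100 example in the text.
labels = ["majority"] * 10000 + ["minority"] * 100
ratio, counts = imbalance_ratio(labels)
print(f"{ratio:.0f}:1", dict(counts))
```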

The class (or classes) in an imbalanced classification problem that has many instances is known as the Majority Class(es).

The class (or classes) in an imbalanced classification problem that has few instances is known as the Minority Class(es).

Why can Imbalanced classes cause problems?

When working with imbalanced data, the minority class is usually the one we care about. Spam detection is an example: “spam” emails are far fewer than “not spam” emails, yet they are the ones we want to catch. Machine learning algorithms nevertheless tend to favor the larger class, and may ignore the smaller class entirely if the data is highly imbalanced.

Machine learning algorithms learn from the training data by minimizing a loss, which usually amounts to maximizing overall accuracy. Let’s see how a machine learning algorithm behaves on highly imbalanced data.

An Example

Consider this example where there are 9900 instances of Class “A” and 100 instances of Class “B”.

from sklearn.datasets import make_classification
x, y = make_classification(n_samples=10000, weights=[0.99], flip_y=0)  # 99% of samples get label 0

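The count plot of the dataset can be created with the seaborn library:

import numpy as np
import seaborn as sns
print(np.unique(y, return_counts=True))  # class counts before relabeling
y = np.where(y == 0, 'A', 'B')  # majority label 0 becomes 'A', minority label 1 becomes 'B'
sns.countplot(x=y)

Figure: count plot of the dataset.
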
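from sklearn.model_selection import train_test_split
xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.20, random_state=42)
print(np.unique(ytrain, return_counts=True))  # class counts in the training split
print(np.unique(ytest, return_counts=True))   # class counts in the test split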

After splitting the dataset into training and test data using train_test_split with a test size of 20%, we are left with 7919 training examples for Class “A” and 81 training examples for Class “B”. Testing examples are 1981 for Class “A” and 19 for Class “B”.
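
Because the split is random, the number of Class “B” examples that land in each part can vary from run to run. One option, not used in the original example, is the `stratify` parameter of `train_test_split`, which keeps the class proportions roughly equal in both parts. The sketch below assumes a fixed `random_state` on `make_classification` purely for repeatability and keeps the raw 0/1 labels.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# random_state on make_classification is an assumption added for repeatability.
x, y = make_classification(n_samples=10000, weights=[0.99], flip_y=0, random_state=0)

# stratify=y asks the splitter to keep the class proportions the same in both splits.
xtrain, xtest, ytrain, ytest = train_test_split(
    x, y, test_size=0.20, random_state=42, stratify=y
)

train_counts = np.bincount(ytrain)  # [majority, minority] counts in the training split
test_counts = np.bincount(ytest)    # [majority, minority] counts in the test split
print(train_counts, test_counts)
```

With a 99:1 dataset, this should leave about 80 minority examples in the training split and about 20 in the test split.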

Let’s first train a Logistic Regression model with our training data.

from sklearn.linear_model import LogisticRegression
lr = LogisticRegression()
lr.fit(xtrain, ytrain)
lr.score(xtest, ytest)  # mean accuracy on the test split

Now, if we check the accuracy of the model using the score method, it is 0.992. 99.2% accuracy? It’s performing great, right? Let’s check the confusion matrix.

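from sklearn.metrics import confusion_matrix
pred_lr = lr.predict(xtest)
print(confusion_matrix(ytest, pred_lr))  # rows: true class, columns: predicted class

Figure: confusion matrix for Logistic Regression.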

Although every Class “A” test example was classified correctly, only 3 of the 19 Class “B” test examples were. It must be a mistake, right?
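
The gap between overall accuracy and per-class performance can be checked by hand. The sketch below rebuilds the confusion matrix from the counts quoted in the text (it is not the output of a fresh run) and computes accuracy alongside the fraction of each class that was classified correctly, often called recall.

```python
import numpy as np

# Confusion matrix rebuilt from the numbers quoted in the text (not from a fresh run).
# Rows are true classes, columns are predicted classes, both ordered "A", "B".
cm = np.array([[1981, 0],
               [16, 3]])

accuracy = np.trace(cm) / cm.sum()               # (1981 + 3) / 2000
recall_per_class = np.diag(cm) / cm.sum(axis=1)  # correct / actual, per class

print(f"accuracy: {accuracy:.3f}")
print(f"recall A: {recall_per_class[0]:.3f}, recall B: {recall_per_class[1]:.3f}")
```

Accuracy comes out at 0.992, while only 3/19 (about 16%) of Class “B” is recovered.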

Let’s use Random Forest Classifier on the same dataset and check what’s happening.

from sklearn.ensemble import RandomForestClassifier
rfc = RandomForestClassifier()
rfc.fit(xtrain, ytrain)
rfc.score(xtest, ytest)  # mean accuracy on the test split

The accuracy score is 0.991 this time, but what did we learn last time? The real results hide behind the accuracy. Let’s check the confusion matrix for Random Forest Classifier’s Predictions.

pred_rfc = rfc.predict(xtest)
print(confusion_matrix(ytest, pred_rfc))

Figure: confusion matrix for Random Forest Classifier.

Only 1 out of 1981 testing examples for Class “A” was classified wrong, but only 2 out of 19 testing examples for Class “B” were classified correctly.
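
Metrics that look at each class separately make this failure visible. As a sketch, the labels below are rebuilt from the Random Forest counts quoted in the text (not from a fresh run) and scored with accuracy, balanced accuracy (the mean of per-class recall), and the F1 score for Class “B”. These alternative metrics are an illustration and are not part of the original walkthrough.

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

# Labels rebuilt from the Random Forest counts quoted in the text (not from a fresh run).
y_true = np.array(["A"] * 1981 + ["B"] * 19)
y_pred = np.array(["A"] * 1980 + ["B"] * 1    # Class "A": 1 of 1981 misclassified
                  + ["A"] * 17 + ["B"] * 2)   # Class "B": 2 of 19 classified correctly

acc = accuracy_score(y_true, y_pred)
bal_acc = balanced_accuracy_score(y_true, y_pred)  # mean of per-class recall
f1_b = f1_score(y_true, y_pred, pos_label="B")     # F1 for the minority class

print(f"accuracy: {acc:.3f}, balanced accuracy: {bal_acc:.3f}, F1 (B): {f1_b:.3f}")
```

Accuracy stays at 0.991, while balanced accuracy is close to 0.55 and the F1 score for Class “B” is below 0.2.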

What are our machine learning models doing here?

As discussed before, machine learning models try to maximize accuracy, and that is what is happening here. Since instances of Class “A” make up 99% of the data, the models learn to classify them correctly and learn little or nothing about Class “B”, because assigning all of the data to Class “A” already yields 99% accuracy.

You can match the accuracy of these models with a single statement in Python. Shocked?

pred = ['A'] * len(ytest)  # predict the majority class for every test example

This statement creates a list of length 2000 (since the test set is 20% of the 10000 examples) and fills it with “A”. Since about 99% of the test examples belong to Class “A”, we get an accuracy of about 99% from accuracy_score.

from sklearn.metrics import accuracy_score
accuracy_score(ytest, pred)

Figure: confusion matrix for the “pred” list.
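
scikit-learn ships a ready-made version of this majority-class baseline, `DummyClassifier` with `strategy="most_frequent"`. The sketch below is a self-contained variant of the experiment that keeps the raw 0/1 labels and assumes a fixed `random_state` on `make_classification` for repeatability, so its exact numbers may differ slightly from those quoted in the text.

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score
from sklearn.model_selection import train_test_split

# Same kind of 99:1 dataset; random_state here is an assumption added for repeatability.
x, y = make_classification(n_samples=10000, weights=[0.99], flip_y=0, random_state=0)
xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.20, random_state=42)

# The baseline ignores the features and always predicts the most frequent training label.
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(xtrain, ytrain)
pred = baseline.predict(xtest)

baseline_accuracy = accuracy_score(ytest, pred)
minority_recall = recall_score(ytest, pred, pos_label=1)  # share of minority class found

print(f"accuracy: {baseline_accuracy:.3f}, minority recall: {minority_recall:.3f}")
```

A model is only worth keeping if it clearly beats this baseline on a metric that accounts for the minority class.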

How can you handle an imbalanced dataset?

There are many ways to handle an imbalanced dataset. Some require domain knowledge, while others use algorithms to increase the number of minority-class instances (over-sampling) or to decrease the number of majority-class instances (under-sampling).
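
As a minimal sketch of both ideas, the code below uses plain random resampling with `sklearn.utils.resample`. This is only one simple way to over-sample and under-sample, not necessarily the approach the author has in mind. The dataset and `random_state` values are assumptions chosen to mirror the earlier example.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.utils import resample

# random_state values are assumptions added for repeatability.
x, y = make_classification(n_samples=10000, weights=[0.99], flip_y=0, random_state=0)

x_major, y_major = x[y == 0], y[y == 0]   # majority class rows
x_minor, y_minor = x[y == 1], y[y == 1]   # minority class rows

# Over-sampling: draw minority rows with replacement until both classes are the same size.
x_minor_up, y_minor_up = resample(x_minor, y_minor, replace=True,
                                  n_samples=len(y_major), random_state=0)
x_over = np.vstack([x_major, x_minor_up])
y_over = np.concatenate([y_major, y_minor_up])

# Under-sampling: keep only as many majority rows as there are minority rows.
x_major_down, y_major_down = resample(x_major, y_major, replace=False,
                                      n_samples=len(y_minor), random_state=0)
x_under = np.vstack([x_major_down, x_minor])
y_under = np.concatenate([y_major_down, y_minor])

print(np.bincount(y_over), np.bincount(y_under))
```

In practice, resampling is usually applied to the training split only, so that the test split still reflects the original class distribution.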


Why Accuracy Is Not A Good Metric For Imbalanced Data was originally published in Towards AI on Medium, where people are continuing the conversation by highlighting and responding to this story.
