Text Classification with CNNs: Are They Dead After Transformers?
Last Updated on April 29, 2025 by Editorial Team
Author(s): S Aishwarya
Originally published on Towards AI.
In today's AI-driven world, transformer models like BERT, RoBERTa, and GPT dominate the field of NLP. From powering chatbots to summarizing documents, transformers seem to be the gold standard for almost every text-related task.
Given their success, you might wonder:
Are CNNs (Convolutional Neural Networks) still relevant for text classification, or have transformers made them obsolete?
That's the question we're exploring today.
In this guide, we'll compare two approaches to text classification:
CNNs trained from scratch
Transformer models fine-tuned with HuggingFace
CNNs, traditionally famous for image recognition, might sound like an odd choice for text classification in 2025. But don't dismiss them too quickly! CNNs can capture local patterns (like phrases and n-grams) exceptionally well, and with less computational overhead than transformers. They offer fast training, simpler architectures, and surprisingly competitive performance on many text classification tasks.
Building a CNN-based text classification model is straightforward and lightweight. With just a few layers, you can train a deep learning model to classify text efficiently:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Conv1D, GlobalMaxPooling1D, Dense

model = Sequential([
    Embedding(input_dim=vocab_size, output_dim=128, input_length=max_length),
    Conv1D(filters=128, kernel_size=5, activation='relu'),  # each filter scans 5-token windows (roughly 5-grams)
    GlobalMaxPooling1D(),                                    # keep only the strongest response of each filter
    Dense(64, activation='relu'),
    Dense(num_classes, activation='softmax')
])

# Note: categorical_crossentropy expects one-hot labels; the full walkthrough below
# uses sparse_categorical_crossentropy with integer labels instead.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=10, batch_size=32, validation_data=(X_val, y_val))
But there's more to CNNs than just speed and simplicity. In this article, we'll explore why CNNs can still be a powerful tool for text classification, how they work under the hood, and whether they can still compete with modern transformers in real-world NLP tasks.
📦 Dataset: AG News Classification Dataset
We'll be using the widely recognized AG News dataset, built from AG's corpus of news articles, which contains over 120,000 news articles categorized into four major classes.
Dataset LINK: AG News Dataset on Kaggle
The categories are:
- World: International news and global events
- Sports: Sports news and events
- Business: Business, finance, and economic news
- Sci/Tech: Scientific and technological developments
Each entry includes:
- Title: The headline of the news article
- Description: A short summary or description of the article
This dataset is ideal for text classification tasks because it's balanced, clean, and covers diverse topics, making it a great benchmark for comparing CNNs and Transformer-based models.
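If you want to confirm the class balance yourself, a quick check works well. This is a minimal sketch assuming the same CSV layout used later in this article (no header row; columns: class index, title, description):

import pandas as pd

# Count articles per class (1 = World, 2 = Sports, 3 = Business, 4 = Sci/Tech)
df = pd.read_csv("train.csv", header=None, names=["Class Index", "Title", "Description"])
print(df["Class Index"].value_counts())  # roughly 30,000 articles per class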
🔧 Approach 1: CNN Trained from Scratch
📥 Step 1: Import Libraries
Essential packages for data handling, preprocessing, and modeling.
import pandas as pd, numpy as np
import matplotlib.pyplot as plt
import re, nltk
from sklearn.model_selection import train_test_split
from nltk.corpus import stopwords
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Conv1D, GlobalMaxPooling1D, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
📂 Step 2: Load Dataset
Read the AG News dataset.
df = pd.read_csv("train.csv", header=None)
df.columns = ['Class Index', 'Title', 'Description']
🏷️ Step 3: Map Labels
Map numerical class index to readable categories (optional for clarity).
label_mapping = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}
df['Category'] = df['Class Index'].map(label_mapping)
🧹 Step 4: Clean the Text
Preprocess by removing special characters and stopwords.
def clean_text(text):
    text = str(text).lower()
    text = re.sub(r'<.*?>', '', text)     # strip HTML tags
    text = re.sub(r'[^\w\s]', '', text)   # remove punctuation
    text = re.sub(r'\d+', '', text)       # remove digits
    text = " ".join([word for word in text.split() if word not in stop_words])
    return text
df['text'] = df['Title'] + " " + df['Description']
df['text'] = df['text'].apply(clean_text)
🔠 Step 5: Tokenize & Pad
Convert text to padded sequences.
tokenizer = Tokenizer(num_words=50000, oov_token="<OOV>")
tokenizer.fit_on_texts(df['text'])
sequences = tokenizer.texts_to_sequences(df['text'])
word_index = tokenizer.word_index
vocab_size = len(word_index) + 1
max_length = 200
padded_sequences = pad_sequences(sequences, maxlen=max_length, padding='post')
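As a quick, purely illustrative sanity check, you can inspect how a single cleaned article maps to token IDs and confirm the padded shape:

# Tokenize one sample and confirm it is padded to max_length
sample = df['text'].iloc[0]
sample_seq = tokenizer.texts_to_sequences([sample])
sample_padded = pad_sequences(sample_seq, maxlen=max_length, padding='post')
print(sample_seq[0][:10])    # first ten token IDs of the sample
print(sample_padded.shape)   # (1, 200)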
🧪 Step 6: Train-Test Split
Split into training and validation sets.
# Coerce the class index to numeric and drop any malformed rows (e.g. a stray header row)
df['Class Index'] = pd.to_numeric(df['Class Index'], errors='coerce')
df.dropna(subset=['Class Index'], inplace=True)

# Keep the padded sequences aligned with the rows that survived the filtering
X = padded_sequences[df.index]
df.reset_index(drop=True, inplace=True)

y = df['Class Index'].astype(int).values - 1   # shift labels from 1-4 to 0-3
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
🧠 Step 7: Build the CNN Model
Create a simple yet powerful CNN for text classification.
model = Sequential([
    Embedding(input_dim=vocab_size, output_dim=128, input_length=max_length),
    Conv1D(128, 5, activation='relu'),
    GlobalMaxPooling1D(),
    Dense(64, activation='relu'),
    Dropout(0.5),
    Dense(4, activation='softmax')
])
⚙️ Step 8: Compile and Train
Compile and fit the CNN model.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
early_stop = EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)
history = model.fit(X_train, y_train, epochs=5, batch_size=64, validation_data=(X_val, y_val), callbacks=[early_stop])
📈 Step 9: Visualize Accuracy
Plot training vs validation accuracy.
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.title("Training vs Validation Accuracy (CNN)")
plt.show()
✅ Step 10: Evaluate Model
Check the final validation accuracy.
loss, acc = model.evaluate(X_val, y_val)
print(f"\n✅ Final Validation Accuracy: {acc:.4f}")
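For a closer look than a single accuracy number, a per-class breakdown is easy to add with scikit-learn (already imported above for the train/test split); this is just an optional sketch:

import numpy as np
from sklearn.metrics import classification_report

# Per-class precision, recall, and F1 on the validation set
y_pred = np.argmax(model.predict(X_val), axis=1)
print(classification_report(y_val, y_pred, target_names=["World", "Sports", "Business", "Sci/Tech"]))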
🤖 Approach 2: DistilBERT Fine-Tuning with HuggingFace
Why build a CNN from scratch when you can fine-tune a lightweight Transformer like DistilBERT?
DistilBERT is a distilled version of BERT (smaller, faster, cheaper) that retains about 95% of BERT's language understanding performance, making it ideal for resource-constrained environments.
🛠️ Step-by-Step Implementation
📦 1. Install Required Libraries
pip install transformers datasets tensorflow
📥 2. Load and Prepare the Dataset
import pandas as pd
from sklearn.model_selection import train_test_split
df = pd.read_csv("train.csv", header=None)
df.columns = ['Class Index', 'Title', 'Description']

# Same guard as in Approach 1: coerce the label column and drop malformed rows
df['Class Index'] = pd.to_numeric(df['Class Index'], errors='coerce')
df.dropna(subset=['Class Index'], inplace=True)

df['text'] = df['Title'] + " " + df['Description']
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['text'].tolist(), (df['Class Index'].astype(int) - 1).tolist(), test_size=0.2, random_state=42)
✂️ 3. Tokenization with DistilBERT Tokenizer
from transformers import DistilBertTokenizer
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=512)
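The tokenizer returns a dictionary-like object holding input_ids and attention_mask, which is why dict(...) works on it in the next step. Purely as an illustration:

# Peek at the structure the tokenizer produces
print(train_encodings.keys())                 # dict_keys(['input_ids', 'attention_mask'])
print(train_encodings['input_ids'][0][:10])   # first ten token IDs of the first training text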
🧾 4. Prepare TensorFlow Datasets
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices((
    dict(train_encodings),
    train_labels
)).shuffle(1000).batch(16)

val_dataset = tf.data.Dataset.from_tensor_slices((
    dict(val_encodings),
    val_labels
)).batch(16)
🧠 5. Load and Fine-Tune DistilBERT
from transformers import TFDistilBertForSequenceClassification

model = TFDistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=4)

# The model outputs raw logits, so the loss needs from_logits=True;
# transformers' AdamW is a PyTorch optimizer, so we use Keras' Adam instead.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_dataset, validation_data=val_dataset, epochs=3)
✅ 6. Evaluate Model Performance
loss, accuracy = model.evaluate(val_dataset)
print(f"\n✅ Final Validation Accuracy: {accuracy:.4f}")
Output:
✅ Final Validation Accuracy: 0.9452
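Once fine-tuned, running the model on a new headline is straightforward. Here is a minimal sketch; the headline is made up purely for illustration:

# Classify a single (hypothetical) headline with the fine-tuned model
labels = ["World", "Sports", "Business", "Sci/Tech"]
sample = "Tech giant unveils new AI chip for data centers"
inputs = tokenizer([sample], truncation=True, padding=True, max_length=512, return_tensors="tf")
logits = model(inputs).logits
pred = int(tf.argmax(logits, axis=-1).numpy()[0])
print(labels[pred])  # most likely Sci/Tech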
📊 Comparative Analysis: CNN vs DistilBERT
🎯 Conclusion
Transformers like DistilBERT outperform CNNs in overall accuracy because they model the context and semantics of language more deeply. However, CNNs are still surprisingly strong contenders, especially when speed, simplicity, and limited computational resources are your priorities.
If you're building lightweight apps or educational models, or you need fast training without access to heavy hardware, CNNs absolutely hold their ground.
⚖️ When to Use CNN vs DistilBERT
Use CNN if:
- You want simpler models for fast training
- You have limited computational resources
- You want quick prototyping for text tasks
Use DistilBERT if:
- You need higher accuracy
- Your datasets are large, noisy, or diverse in language
- You can afford slightly more compute power
💡 Final Thoughts
CNNs are not dead for text classification! While Transformers dominate benchmarks, CNNs offer faster, cheaper, and easier solutions, especially when data and resources are limited. Choosing the right model depends on your needs, not just the latest trend!