Fine-Tuning Legal-BERT: LLMs For Automated Legal Text Classification
Author(s): Drewgelbard
Originally published on Towards AI.
Unlocking efficient legal document classification with NLP fine-tuning
Introduction
In today's fast-paced legal industry, professionals are inundated with an ever-growing volume of complex documents, from intricate contract provisions and merger agreements to regulatory compliance records and court filings. Manually sifting through these documents is not only labor-intensive and time-consuming but also prone to human error and inconsistency. This inefficiency can lead to overlooked risks, non-compliance with regulations, and ultimately financial damage for organizations.
The Challenge
Legal texts are uniquely challenging for natural language processing (NLP) due to their specialized vocabulary, intricate syntax, and the critical importance of context. Terms that appear similar in general language can have vastly different meanings in legal contexts. Therefore, generic NLP models often fall short when applied directly to legal documents.
The Solution
This is where fine-tuning specialized language models comes into play. By adapting models that are pre-trained on legal corpora, we can achieve higher accuracy and reliability in tasks like contract analysis, compliance monitoring, and legal document retrieval. In this article, we will delve into how Legal-BERT [5], a transformer-based model tailored for legal texts, can be fine-tuned to classify contract provisions using the LEDGAR dataset [4], a comprehensive benchmark dataset specifically designed for the legal field.
What You'll Learn
By the end of this tutorial, you'll have a complete roadmap for leveraging Legal-BERT to tackle legal text classification. This guide covers:
- Setting up your environment for NLP tasks involving legal documents.
- Understanding and preprocessing the LEDGAR dataset for optimal model performance.
- Performing exploratory data analysis to gain insights into the dataset's structure.
- Fine-tuning Legal-BERT for multi-class classification of legal provisions.
- Evaluating the modelβs performance against established benchmarks.
- Discussing challenges and considerations specific to legal NLP applications.
Whether you're a data scientist aiming to deepen your expertise in NLP or a machine learning engineer interested in domain-specific model fine-tuning, this tutorial will equip you with the tools and insights you need to get started.
Table of Contents
- Environment Setup
- Dataset Overview
- Preprocessing and Tokenization
- Exploratory Data Analysis (EDA)
- Training and Fine-Tuning
- Evaluating the Model
- Conclusion and Key Takeaways
Environment Setup
We will use the Hugging Face Transformers library, which offers pre-trained models and tools to fine-tune them. While not strictly necessary, using a GPU will speed up training significantly. If you're using Google Colab, enable the GPU by going to Runtime > Change runtime type and selecting GPU.
First, install the necessary libraries:
!pip install transformers datasets torch scikit-learn
# Import necessary dependencies
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
    precision_recall_curve,
)
# Set device for GPU usage
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
Dataset Overview
The dataset chosen for this project is LEDGAR (Labeled EDGAR), part of the LexGLUE benchmark for legal language tasks [1, 3]. LEDGAR consists of contract provisions from publicly available SEC filings, also known as Exhibit 10 contracts, which are essential in legal practice. The dataset includes around 80,000 provisions labeled across 100 categories, from "Agreements" and "Confidentiality" to "Termination" and "Vesting" [3].
LEDGAR presents a challenging dataset for NLP models due to its diverse terminology and context-specific labels. The provisions are divided into training, validation, and test sets, with 60,000 provisions for training, 10,000 for validation, and 10,000 for testing. For this tutorial, we'll download and prepare the dataset using Hugging Face's datasets library.
I recommend reviewing the dataset card [4] to gain a better understanding of the dataset and the LexGLUE benchmark.
# Load LEDGAR dataset
dataset = load_dataset('lex_glue', 'ledgar')
# Display dataset features
print(dataset['train'].features)
# Get label information
label_list = dataset['train'].features['label'].names
num_labels = len(label_list)
print(f"Number of labels: {num_labels}")
Here is an example of what the training data looks like [4]:
{
"text": "Executive agrees to be employed with the Company, and the Company agrees to employ Executive, during the Term and on the terms and conditions set forth in this Agreement. Executive agrees during the term of this Agreement to devote substantially all of Executiveβs business time, efforts, skills and abilities to the performance of Executiveβs duties ...",
"label": "Employment"
}
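Note that in the Hugging Face version of LEDGAR, the label field is stored as an integer class index rather than a string; the readable name can be recovered from label_list (defined in the previous snippet). A quick illustration:
# Inspect one training example and map its integer label back to a name
example = dataset['train'][0]
print(example['text'][:200])
print(example['label'], '->', label_list[example['label']])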
Preprocessing and Tokenization
To fine-tune Legal-BERT effectively, we need to prepare the LEDGAR dataset with several preprocessing steps:
- Mapping Labels to Indices: Create mappings between label names and indices to ensure compatibility with PyTorch during training.
- Computing Token Lengths: Calculate the token length of each text example. This helps us understand the data distribution and check that the maximum sequence length (set to 512 tokens) is appropriate for the dataset.
- Tokenizing Texts: Tokenize each provision using Legal-BERT's tokenizer, which is designed to handle legal terminology.
- Truncating and Padding Sequences: Truncate texts longer than the maximum length of 512 tokens and pad shorter ones, ensuring consistent input lengths.
# Create mappings from label names to indices and vice versa
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for idx, label in enumerate(label_list)}
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
# Token length computation function
def compute_token_lengths(example):
    tokens = tokenizer.encode(example['text'], add_special_tokens=True)
    example['num_tokens'] = len(tokens)
    return example

# Apply token length computation to the dataset
dataset = dataset.map(compute_token_lengths)

def preprocess_data(examples):
    # Tokenize the texts
    return tokenizer(
        examples['text'],
        truncation=True,        # Truncate texts longer than max_length
        padding='max_length',   # Pad texts shorter than max_length
        max_length=512
    )

# Apply the preprocessing function to the dataset
encoded_dataset = dataset.map(preprocess_data, batched=True)

# Set the format of the dataset to PyTorch tensors
encoded_dataset.set_format(
    type='torch',
    columns=['input_ids', 'attention_mask', 'label']
)
encoded_dataset
Exploratory Data Analysis (EDA)
EDA is an essential step in any machine learning workflow, especially when working with large and complex datasets like LEDGAR. By examining the data's structure, distribution, and key characteristics, we can make informed decisions about preprocessing and model setup.
Token Length Distribution
Since our model (Legal-BERT) has a maximum input length of 512 tokens, understanding the token length distribution is important to assess whether the data fits within this constraint. Here's a histogram displaying the token length distribution for the training, validation, and test sets.
As we can see, while most of the provisions fall under the 512-token limit, some exceed it. Truncating sequences longer than 512 tokens is necessary to maintain consistency across inputs. This also aligns with the LexGLUE benchmark setup, allowing our results to be compared with previous research [1].
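For reference, here is a minimal sketch of how such a histogram can be generated from the num_tokens column computed earlier (a fuller version appears in the Code section at the end):
# Plot the token length distribution for each split
plt.figure(figsize=(10, 6))
for split in ['train', 'validation', 'test']:
    sns.histplot(dataset[split]['num_tokens'], bins=50, kde=True, label=split.capitalize(), alpha=0.5)
plt.axvline(512, color='red', linestyle='--', label='512-token limit')
plt.xlabel('Number of Tokens')
plt.ylabel('Frequency')
plt.legend()
plt.show()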
Class Distribution
LEDGAR contains 100 different classes with a wide range of frequencies across categories. Below are bar charts showing the top 10 and bottom 10 classes in terms of frequency in the training set.
There is significant class imbalance in the dataset. Certain classes, such as "Governing Laws" and "Notices," have a high number of examples, while others, such as "Books" and "Assignments," are less prevalent.
While addressing class imbalance can improve model performance, particularly on underrepresented classes, we have chosen to retain the original distribution in this case. This decision was made to ensure a fair comparison with the LexGLUE benchmark, which uses the unbalanced LEDGAR dataset [1]. In a production setting, however, techniques like class weighting or data augmentation could be used to balance the classes, enhancing the model's ability to generalize across all categories.
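For reference, here is a minimal sketch of what class weighting could look like, using a custom Trainer subclass with a weighted cross-entropy loss. This is illustrative only and is not used in the experiments below; the WeightedTrainer name and the inverse-frequency weighting are our own choices:
from collections import Counter
import torch.nn as nn

# Inverse-frequency weights computed from the training labels
train_labels = dataset['train']['label']
counts = Counter(train_labels)
class_weights = torch.tensor(
    [len(train_labels) / (num_labels * max(counts[i], 1)) for i in range(num_labels)],  # guard against empty classes
    dtype=torch.float
).to(device)

class WeightedTrainer(Trainer):
    # Override the default loss with a class-weighted cross-entropy
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss = nn.CrossEntropyLoss(weight=class_weights)(outputs.logits, labels)
        return (loss, outputs) if return_outputs else loss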
Now that we have completed the necessary preprocessing and EDA steps, let's move on to our fine-tuning methodology.
Training and Fine-Tuning
Fine-tuning Legal-BERT on LEDGAR involves configuring the model for a multi-class classification task with 100 categories. We prioritize macro F1 and micro F1 scores as our primary evaluation metrics, consistent with the LexGLUE benchmark, which uses these metrics to assess model performance [1, 4]. Let's walk through how to set up and train the model.
Model and Training Configuration
We load Legal-BERT with a sequence classification head, specifying 100 output classes. Additionally, we define mappings for label names and indices to ensure proper label encoding.
# Load Legal-BERT with a classification head for 100 classes
model = AutoModelForSequenceClassification.from_pretrained(
    "nlpaueb/legal-bert-base-uncased",
    num_labels=num_labels,  # Number of labels (100)
    id2label=id2label,      # Mapping from IDs to labels
    label2id=label2id       # Mapping from labels to IDs
)
model.to(device)
Custom Metrics for Evaluation
As previously stated, we will focus on the Macro F1 and Micro F1 performance metrics. Macro F1 averages the F1 scores of each class equally, providing insight into how well the model handles underrepresented categories, while Micro F1 aggregates the results across all classes, giving a holistic view of performance. By using these metrics, we ensure our results are directly comparable to the LexGLUE benchmark.
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    # Calculate accuracy
    accuracy = accuracy_score(labels, predictions)
    # Calculate macro F1-score
    macro_f1 = f1_score(labels, predictions, average='macro', zero_division=0)
    # Calculate micro F1-score
    micro_f1 = f1_score(labels, predictions, average='micro', zero_division=0)
    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'micro_f1': micro_f1
    }
Training Arguments
Next, we need to establish the training arguments, including parameters such as batch size, number of epochs, learning rate, and evaluation strategy. The checkpoint saved is the "best" version based on macro F1 score, which ensures the selected model has the most balanced performance across all classes.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=5,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    logging_dir='./logs',
    load_best_model_at_end=True,
    metric_for_best_model='macro_f1',
    greater_is_better=True,
    fp16=True,  # Enables mixed precision for faster training
    logging_steps=100
)
Training the Model
Using Hugging Face's Trainer class, we train Legal-BERT on LEDGAR with the specified arguments and evaluation metrics. This setup simplifies the training process and provides automated evaluation after each epoch.
# Initialize the Trainer with the model, arguments, and training data
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['validation'],
    compute_metrics=compute_metrics
)
# Start fine-tuning
trainer.train()
Final Evaluation on the Test Set
After training is complete, we evaluate the model on the validation set one last time and then generate predictions on the held-out test set to obtain final performance metrics. This allows us to assess how well the model generalizes to unseen data.
# Evaluate on the validation set
eval_results = trainer.evaluate()
# Extract and print macro and micro F1 scores
val_macro_f1 = eval_results.get('eval_macro_f1')
val_micro_f1 = eval_results.get('eval_micro_f1')
print("Validation Results:")
print(f"Validation Macro F1-score: {val_macro_f1:.4f}")
print(f"Validation Micro F1-score: {val_micro_f1:.4f}")
# Predict on the test set
test_results = trainer.predict(encoded_dataset['test'])
# Extract predictions and true labels
test_logits, test_labels = test_results.predictions, test_results.label_ids
test_predictions = np.argmax(test_logits, axis=1)
# Calculate test metrics
test_accuracy = accuracy_score(test_labels, test_predictions)
test_macro_f1 = f1_score(test_labels, test_predictions, average='macro', zero_division=0)
test_micro_f1 = f1_score(test_labels, test_predictions, average='micro', zero_division=0)
# Print test results
print("Test Results:")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test Macro F1-score: {test_macro_f1:.4f}")
print(f"Test Micro F1-score: {test_micro_f1:.4f}")
By focusing on Macro and Micro F1 scores, we gain a balanced view of the model's performance across all classes, which is essential for imbalanced datasets like LEDGAR. This evaluation approach aligns with the LexGLUE benchmark [1, 4], allowing us to compare our results with industry-standard models for legal text classification. Below are the reported benchmark results compared to our model.
Although our model's performance is closely aligned with the benchmark results, slight differences may arise due to several factors. Variations in training configuration, such as hyperparameters, batch size, number of epochs, or learning rate, can affect the outcomes. Differences in data preprocessing, like handling rare classes or sequence truncation methods, might also contribute to the discrepancies. Lastly, the stochastic nature of training means that different random seeds can lead to variability in the results, explaining the minor performance gaps observed.
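On the last point, fixing the random seed narrows run-to-run variance; the full code listing below sets seed=42 in TrainingArguments, and the transformers library also exposes a helper that seeds Python, NumPy, and PyTorch in one call:
# Fix all relevant RNGs for a more reproducible run
from transformers import set_seed
set_seed(42)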
Conclusion and Key Takeaways
In this tutorial, we explored fine-tuning Legal-BERT on the LEDGAR dataset to classify legal contract provisions. By leveraging a domain-specific model, we were able to achieve strong performance in handling the nuanced language of legal documents, showcasing the power of fine-tuning in specialized NLP tasks.
Key Takeaways:
- Pre-trained legal models are invaluable: Legal-specific models like Legal-BERT are essential for accurately managing complex, specialized vocabulary and domain-specific nuances in legal texts.
- Fine-tuning for specialization: Adapting BERT models to specific tasks through fine-tuning enhances their classification performance and robustness, especially when dealing with imbalanced classes and varied categories in datasets like LEDGAR.
- Consistency with benchmarks: Aligning with established benchmarks, such as LexGLUE, allows for objective comparisons and insights into how your model stacks up against industry standards.
If you found this tutorial useful, follow me for my next article, where I'll dive into an advanced technique: using Retrieval-Augmented Generation (RAG) combined with topic modeling to improve classification performance on complex datasets. This upcoming tutorial will build on ideas from the topic-alignment article by Ben McCloskey [2], exploring how topic alignment can further enhance model performance in NLP.
References
[1] Chalkidis, I., Jana, A., Hartung, D., et al. "LexGLUE: A Benchmark Dataset for Legal Language Understanding in English," 2021. arXiv:2110.00976.
[2] McCloskey, B.J. "Topic Alignment for NLP Recommender Systems," Medium, 2024.
[3] Tuggener, D., von Däniken, P., Peetz, T., Cieliebak, M. "LEDGAR: A Large-Scale Multi-label Corpus for Text Classification of Legal Provisions in Contracts," LREC 2020. (CC BY-SA 4.0)
[4] Hugging Face. "LexGLUE" dataset card (lex_glue), Hugging Face Hub.
[5] Hugging Face. "Legal-BERT" model card (nlpaueb/legal-bert-base-uncased), Hugging Face Hub. MIT License.
Code
!pip install transformers datasets torch scikit-learn
# Import necessary dependencies
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
    precision_recall_curve,
)
# Set device for GPU usage
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Load LEDGAR dataset
dataset = load_dataset('lex_glue', 'ledgar')
# Display dataset features
print(dataset['train'].features)
# Get label information
label_list = dataset['train'].features['label'].names
num_labels = len(label_list)
print(f"Number of labels: {num_labels}")
# Create mappings from label names to indices and vice versa
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for idx, label in enumerate(label_list)}
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
# Token length computation function
def compute_token_lengths(example):
    tokens = tokenizer.encode(example['text'], add_special_tokens=True)
    example['num_tokens'] = len(tokens)
    return example

# Apply token length computation to the dataset
dataset = dataset.map(compute_token_lengths)

def preprocess_data(examples):
    # Tokenize the texts
    return tokenizer(
        examples['text'],
        truncation=True,        # Truncate texts longer than max_length
        padding='max_length',   # Pad texts shorter than max_length
        max_length=512
    )

# Apply the preprocessing function to the dataset
encoded_dataset = dataset.map(preprocess_data, batched=True)

# Set the format of the dataset to PyTorch tensors
encoded_dataset.set_format(
    type='torch',
    columns=['input_ids', 'attention_mask', 'label']
)
encoded_dataset
# Plot token length distribution across the dataset splits
def plot_token_length_distribution(dataset, title='Token Length Distribution'):
    plt.figure(figsize=(10, 6))
    for subset, color in zip(['train', 'validation', 'test'], ['blue', 'orange', 'green']):
        lengths = dataset[subset]['num_tokens']
        sns.histplot(lengths, bins=50, kde=True, label=subset.capitalize(), color=color, stat="frequency", alpha=0.5)
    plt.title(title)
    plt.xlabel('Number of Tokens')
    plt.ylabel('Frequency')
    plt.legend()
    plt.show()
plot_token_length_distribution(dataset, title="Token Length Distribution (Full Dataset)")
# Function to plot class distribution for top and bottom classes
def plot_class_distribution(dataset, split='train', top_n=10):
    label_counts = pd.Series(dataset[split]['label']).value_counts()
    # Top 10 labels
    top_labels = label_counts.head(top_n)
    top_label_names = [label_list[i] for i in top_labels.index]
    # Bottom 10 labels
    bottom_labels = label_counts.tail(top_n)
    bottom_label_names = [label_list[i] for i in bottom_labels.index]
    # Plot the top 10 class distribution
    plt.figure(figsize=(12, 6))
    sns.barplot(x=top_label_names, y=top_labels.values, palette='Blues_d')
    plt.xticks(rotation=45, fontsize=12)
    plt.title(f'Top {top_n} Class Distribution - {split}', fontsize=14)
    plt.xlabel('Class Label', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.show()
    # Plot the bottom 10 class distribution
    plt.figure(figsize=(12, 6))
    sns.barplot(x=bottom_label_names, y=bottom_labels.values, palette='Reds_d')
    plt.xticks(rotation=45, fontsize=12)
    plt.title(f'Bottom {top_n} Class Distribution - {split}', fontsize=14)
    plt.xlabel('Class Label', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.show()
# Call the function to plot for the train split
plot_class_distribution(dataset, 'train')
# Load Legal-BERT with a classification head for 100 classes
model = AutoModelForSequenceClassification.from_pretrained(
    "nlpaueb/legal-bert-base-uncased",
    num_labels=num_labels,  # Number of labels (100)
    id2label=id2label,      # Mapping from IDs to labels
    label2id=label2id       # Mapping from labels to IDs
)
model.to(device)
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    # Calculate accuracy
    accuracy = accuracy_score(labels, predictions)
    # Calculate macro F1-score
    macro_f1 = f1_score(labels, predictions, average='macro', zero_division=0)
    # Calculate micro F1-score
    micro_f1 = f1_score(labels, predictions, average='micro', zero_division=0)
    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'micro_f1': micro_f1
    }
os.environ["WANDB_DISABLED"] = "true"
# Your Trainer arguments below
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=5,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    logging_dir='./logs',
    load_best_model_at_end=True,
    metric_for_best_model='macro_f1',
    greater_is_better=True,
    fp16=True,
    logging_steps=100,
    report_to="none",  # This disables all logging integrations like W&B
    seed=42
)
trainer = Trainer(
    model=model,                                 # The pre-trained model
    args=training_args,                          # Training arguments
    train_dataset=encoded_dataset['train'],      # Training dataset
    eval_dataset=encoded_dataset['validation'],  # Validation dataset
    compute_metrics=compute_metrics              # Evaluation metrics
)
# Start fine-tuning
trainer.train()
# Evaluate on the validation set
eval_results = trainer.evaluate()
# Extract and print macro and micro F1 scores
val_macro_f1 = eval_results.get('eval_macro_f1')
val_micro_f1 = eval_results.get('eval_micro_f1')
print("Validation Results:")
print(f"Validation Macro F1-score: {val_macro_f1:.4f}")
print(f"Validation Micro F1-score: {val_micro_f1:.4f}")
# Predict on the test set
test_results = trainer.predict(encoded_dataset['test'])
# Extract predictions and true labels
test_logits, test_labels = test_results.predictions, test_results.label_ids
test_predictions = np.argmax(test_logits, axis=1)
# Calculate test metrics
test_accuracy = accuracy_score(test_labels, test_predictions)
test_macro_f1 = f1_score(test_labels, test_predictions, average='macro', zero_division=0)
test_micro_f1 = f1_score(test_labels, test_predictions, average='micro', zero_division=0)
# Print test results
print("Test Results:")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test Macro F1-score: {test_macro_f1:.4f}")
print(f"Test Micro F1-score: {test_micro_f1:.4f}")
Published via Towards AI