AI-Enhanced Graph Analytics in Data Science and a Step-by-Step Guide on Building a GNN
Last Updated on November 1, 2024 by Editorial Team
Author(s): Bhavesh Agone
Originally published on Towards AI.
Graph analytics has emerged as one of the most important tools in data science for understanding complex relationships and structures within data. Graphs can represent many kinds of systems, from social networks to biological systems, and analyzing them can reveal hidden patterns and insights. With advances in AI in general and deep learning in particular, the field of graph analytics has seen significant progress. This blog post covers AI-enhanced graph analytics, its principles and applications, and works through code examples that illustrate what can be done.
What is Graph Analytics?
Graph analytics is the study of graphs, which are mathematical structures that model pairwise relations between objects. A graph consists of nodes, also known as vertices, and edges, which are the links that connect them. Applications of graph analytics include:
1. Social network analysis: understanding relationships and influence among individuals.
2. Recommendation systems: recommending products or services based on user behavior.
3. Fraud detection: detecting fraud by analyzing transaction networks.
4. Biological network analysis: studying protein-protein interactions and gene regulatory networks.
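To make the node-and-edge vocabulary concrete, here is a minimal sketch in plain Python. The friendship network below is made up for illustration and does not come from any dataset. It stores an undirected graph as an adjacency list and computes the degree of each node, one of the simplest graph statistics.

```python
from collections import defaultdict

# Hypothetical undirected edges (illustrative names only)
edges = [("alice", "bob"), ("alice", "carol"), ("bob", "carol"), ("carol", "dave")]

def build_adjacency(edge_list):
    """Map every node to the set of nodes it shares an edge with."""
    adjacency = defaultdict(set)
    for u, v in edge_list:
        adjacency[u].add(v)
        adjacency[v].add(u)  # undirected: record the edge in both directions
    return dict(adjacency)

adjacency = build_adjacency(edges)

# Degree = number of neighbors of each node
degree = {node: len(neighbors) for node, neighbors in adjacency.items()}
print(degree)
```

An adjacency list is a common choice for sparse graphs because it stores only the edges that exist.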
Introduction to Graph Neural Networks (GNNs)
GNNs are a class of neural networks designed for inference on graph-structured data. They combine the relational structure encoded by graphs with the feature-learning capability of neural networks. The basic GNN model follows these steps:
- Node embedding initialization: Initialize each node with a feature vector.
- Message passing: Nodes send messages to their neighbors.
- Aggregation: Each node aggregates the messages it receives to update its embedding.
- Prediction: Use the updated embeddings for downstream tasks, such as node classification or link prediction.
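The steps above can be sketched in a few lines of plain Python. This is a simplified illustration, not how PyTorch Geometric implements it: it assumes mean aggregation over a node and its neighbors and leaves out the learned weight matrix and nonlinearity that a real GNN layer applies. The tiny path graph and feature vectors are made up.

```python
def message_passing_round(features, adjacency):
    """One round: each node averages its own vector with its neighbors' vectors."""
    updated = {}
    for node, own in features.items():
        # Messages are simply the neighbors' current feature vectors
        messages = [features[n] for n in adjacency.get(node, [])]
        stacked = [own] + messages
        # Aggregation: element-wise mean over the node and its neighbors
        updated[node] = [
            sum(vec[i] for vec in stacked) / len(stacked) for i in range(len(own))
        ]
    return updated

# Hypothetical path graph 0 - 1 - 2 with 2-dimensional initial features
adjacency = {0: [1], 1: [0, 2], 2: [1]}
features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}

embeddings = message_passing_round(features, adjacency)
print(embeddings)
```

Calling the function a second time on `embeddings` would let information travel two hops, which is why stacking layers widens each node's receptive field.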
Why Graph Neural Networks (GNNs)?
Traditional graph algorithms, from Dijkstra's shortest path to PageRank, tell us a lot, but they compute predefined quantities rather than learning from node features or labels, which limits them on large, complex graphs. AI, especially Graph Neural Networks (GNNs), has emerged as a powerful technique for enhancing graph analytics by using deep learning to learn and infer patterns in graph-structured data.
Advantages of GNNs:
- Scalability: GNNs can handle large graphs with millions of nodes and edges.
- Flexibility: They work with many types of graphs, including directed, undirected, weighted, and multi-relational graphs.
- Learning ability: GNNs learn complex patterns and features directly from the data, so handcrafted features are not required.
- End-to-end training: Feature extraction and prediction are optimized simultaneously.
Three basic types of Graph Neural Networks (GNNs):
- Graph Convolutional Network (GCN): Uses layer-wise propagation over the features of a node's one-hop neighbors. It is designed to scale to large graphs and is commonly used for semi-supervised tasks.
- Graph Attention Network (GAT): Applies an attention mechanism to learn importance weights for neighboring nodes, so each node gathers information adaptively. Multi-head attention further improves its representational ability.
- GraphSAGE: Samples a node's neighbors and aggregates their features using one of several aggregation functions. This makes the model inductive, so it can handle large or dynamic graphs.
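The difference between a plain mean aggregator (one of the options GraphSAGE offers) and attention-weighted aggregation (the idea behind GAT) can be illustrated with scalar features. The sketch below is an assumption-laden simplification: in a real GAT the raw scores come from learned parameters, whereas here they are hypothetical numbers passed in by hand.

```python
import math

def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    shift = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def mean_aggregate(neighbor_values):
    """Every neighbor contributes equally."""
    return sum(neighbor_values) / len(neighbor_values)

def attention_aggregate(neighbor_values, raw_scores):
    """Neighbors contribute in proportion to their softmax-normalized scores."""
    weights = softmax(raw_scores)
    return sum(w * v for w, v in zip(weights, neighbor_values))

neighbor_values = [1.0, 2.0, 3.0]  # hypothetical scalar features of three neighbors

uniform = attention_aggregate(neighbor_values, [0.0, 0.0, 0.0])  # equal scores
focused = attention_aggregate(neighbor_values, [0.0, 0.0, 5.0])  # favors the last neighbor
print(mean_aggregate(neighbor_values), uniform, focused)
```

With equal scores, attention reduces to the mean. With one dominant score, the output moves toward that neighbor's value.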
Building a GNN: A Step-by-Step Guide
We will build GNNs for node classification on a sample dataset using Python and popular libraries such as PyTorch and PyTorch Geometric. We'll introduce more advanced concepts step by step as the complexity increases.
Importing Libraries
import torch
import torch.nn.functional as F
from torch_geometric.datasets import CitationFull
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
Loading the Dataset
We'll use the CiteSeer dataset, a popular benchmark for graph learning tasks.
# Load the CiteSeer dataset
dataset = CitationFull(root='/tmp/CiteSeer', name='CiteSeer')
data = dataset[0]
Next, create boolean masks for the train, validation, and test sets by randomly splitting the nodes 60/20/20.
# Create masks for train, validation, and test sets
num_classes = dataset.num_classes
num_nodes = data.num_nodes
# Randomly split data into train, val, and test
indices = np.arange(num_nodes)
np.random.shuffle(indices)
train_size = int(0.6 * num_nodes)
val_size = int(0.2 * num_nodes)
test_size = num_nodes - train_size - val_size
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[indices[:train_size]] = True
val_mask[indices[train_size:train_size + val_size]] = True
test_mask[indices[train_size + val_size:]] = True
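The split above is not seeded, so each run produces a different partition and slightly different accuracies. As a hedged sketch of the same 60/20/20 logic in plain Python, the hypothetical helpers `split_indices` and `to_mask` below use a fixed seed (an assumption, not part of the original code) and make it easy to check that the three sets are disjoint and cover every node.

```python
import random

def split_indices(num_nodes, train_frac=0.6, val_frac=0.2, seed=0):
    """Shuffle node indices with a fixed seed and cut them into three parts."""
    rng = random.Random(seed)
    indices = list(range(num_nodes))
    rng.shuffle(indices)
    train_size = int(train_frac * num_nodes)
    val_size = int(val_frac * num_nodes)
    train = indices[:train_size]
    val = indices[train_size:train_size + val_size]
    test = indices[train_size + val_size:]  # whatever remains
    return train, val, test

def to_mask(selected, num_nodes):
    """Boolean mask that is True exactly at the selected indices."""
    chosen = set(selected)
    return [i in chosen for i in range(num_nodes)]

train_idx, val_idx, test_idx = split_indices(10)
train_mask_demo = to_mask(train_idx, 10)
print(len(train_idx), len(val_idx), len(test_idx))
```

In the NumPy version, calling `np.random.seed(...)` before the shuffle would have a similar effect.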
Model Classes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Two-layer GCN: hidden layer with ReLU, then a layer that outputs class scores
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Two-layer GAT: 8 concatenated attention heads, then a single output head
class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=8, concat=True)
        # The first layer concatenates 8 heads, so its output width is hidden_channels * 8
        self.conv2 = GATConv(hidden_channels * 8, out_channels, heads=1, concat=False)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Two-layer GraphSAGE with the default aggregator of SAGEConv
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
Training & Testing Models
def train(model, data, optimizer):
    # One full-batch training step; the loss uses only the training nodes
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

def test(model, data, mask):
    # Returns the accuracy on the masked nodes and the predictions for all nodes
    model.eval()
    with torch.no_grad():
        pred = model(data.x, data.edge_index).argmax(dim=1)
    return accuracy_score(data.y[mask].cpu(), pred[mask].cpu()), pred
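The models return `log_softmax` outputs and the training step applies `nll_loss`, a pairing that together amounts to the usual cross-entropy loss. The following plain-Python sketch mirrors that arithmetic on a tiny batch of made-up logits. It illustrates the math only and is not the PyTorch implementation.

```python
import math

def log_softmax(logits):
    """Log-probabilities for one row of class scores (numerically stable)."""
    shift = max(logits)
    log_total = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return [z - log_total for z in logits]

def nll_loss(log_probs_batch, targets):
    """Mean negative log-probability assigned to the true class."""
    picked = [row[t] for row, t in zip(log_probs_batch, targets)]
    return -sum(picked) / len(picked)

# Hypothetical logits for two nodes and three classes, with their true labels
logits_batch = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]

log_probs = [log_softmax(row) for row in logits_batch]
loss = nll_loss(log_probs, targets)

# Predicted class = index of the largest log-probability, as argmax does in test()
predictions = [row.index(max(row)) for row in log_probs]
print(loss, predictions)
```

The loss is small here because the largest logit in each row already sits on the true class. It grows as probability mass moves to the wrong classes.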
Hyperparameters
in_channels = dataset.num_node_features
hidden_channels = 64
out_channels = dataset.num_classes
epochs = 200
lr = 0.01
Initialize models
gcn = GCN(in_channels, hidden_channels, out_channels).to(device)
gat = GAT(in_channels, hidden_channels, out_channels).to(device)
sage = GraphSAGE(in_channels, hidden_channels, out_channels).to(device)
# Move data to device
data = data.to(device)
# Initialize optimizers
gcn_optimizer = torch.optim.Adam(gcn.parameters(), lr=lr)
gat_optimizer = torch.optim.Adam(gat.parameters(), lr=lr)
sage_optimizer = torch.optim.Adam(sage.parameters(), lr=lr)
Training
# Training
gcn_train_acc = []
gat_train_acc = []
sage_train_acc = []
gcn_loss_history = []
gat_loss_history = []
sage_loss_history = []
for epoch in range(epochs):
    gcn_loss = train(gcn, data, gcn_optimizer)
    gat_loss = train(gat, data, gat_optimizer)
    sage_loss = train(sage, data, sage_optimizer)
    # Collect training accuracies and losses
    gcn_train_acc.append(test(gcn, data, train_mask)[0])
    gat_train_acc.append(test(gat, data, train_mask)[0])
    sage_train_acc.append(test(sage, data, train_mask)[0])
    gcn_loss_history.append(gcn_loss)
    gat_loss_history.append(gat_loss)
    sage_loss_history.append(sage_loss)
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, GCN Loss: {gcn_loss:.4f}, GAT Loss: {gat_loss:.4f}, SAGE Loss: {sage_loss:.4f}')
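The loop above tracks training accuracy only, and the validation mask created earlier goes unused. One possible extension, offered here as an assumption rather than as part of the original workflow, is to record validation accuracy each epoch and pick the best epoch from that history. The helper `best_epoch` and the accuracy values below are hypothetical and purely illustrative.

```python
def best_epoch(val_accuracies):
    """Return (epoch, accuracy) for the highest validation accuracy.

    Ties go to the earliest epoch.
    """
    best_index = max(
        range(len(val_accuracies)),
        key=lambda i: (val_accuracies[i], -i),
    )
    return best_index, val_accuracies[best_index]

# Made-up validation accuracies for five epochs (not real results)
val_history = [0.50, 0.62, 0.71, 0.71, 0.69]

epoch, accuracy = best_epoch(val_history)
print(epoch, accuracy)
```

In the training loop, such a history could be filled with `test(model, data, val_mask)[0]` in the same way the training accuracies are collected.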
Testing the models
# Test models
gcn_test_acc, gcn_pred = test(gcn, data, test_mask)
gat_test_acc, gat_pred = test(gat, data, test_mask)
sage_test_acc, sage_pred = test(sage, data, test_mask)
print(f'Test Accuracy: GCN: {gcn_test_acc:.4f}, GAT: {gat_test_acc:.4f}, SAGE: {sage_test_acc:.4f}')
Visualize
# Training accuracy plot
plt.figure(figsize=(8, 6))
plt.plot(range(epochs), gcn_train_acc, label='GCN Train Accuracy', color='blue')
plt.plot(range(epochs), gat_train_acc, label='GAT Train Accuracy', color='orange')
plt.plot(range(epochs), sage_train_acc, label='GraphSAGE Train Accuracy', color='green')
plt.xlabel('Epochs')
plt.ylabel('Training Accuracy')
plt.title('Training Accuracy for Different GNN Models')
plt.legend()
plt.grid()
plt.show()
# Loss over epochs plot
plt.figure(figsize=(8, 6))
plt.plot(range(epochs), gcn_loss_history, label='GCN Loss', color='blue')
plt.plot(range(epochs), gat_loss_history, label='GAT Loss', color='orange')
plt.plot(range(epochs), sage_loss_history, label='GraphSAGE Loss', color='green')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Over Epochs for Different GNN Models')
plt.legend()
plt.grid()
plt.show()
# Test accuracy comparison
plt.figure(figsize=(8, 6))
models = ['GCN', 'GAT', 'GraphSAGE']
test_accuracies = [gcn_test_acc, gat_test_acc, sage_test_acc]
plt.bar(models, test_accuracies, color=['blue', 'orange', 'green'])
plt.ylabel('Test Accuracy')
plt.title('Test Accuracy Comparison')
plt.ylim(0, 1)
plt.grid()
plt.show()
# Confusion matrices
def plot_confusion_matrix(y_true, y_pred, model_name):
    cm = confusion_matrix(y_true.cpu(), y_pred.cpu())
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                xticklabels=np.arange(num_classes), yticklabels=np.arange(num_classes))
    plt.title(f'Confusion Matrix for {model_name}')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()
# Plot confusion matrices for each model
plot_confusion_matrix(data.y[test_mask], gcn_pred[test_mask], 'GCN')
plot_confusion_matrix(data.y[test_mask], gat_pred[test_mask], 'GAT')
plot_confusion_matrix(data.y[test_mask], sage_pred[test_mask], 'GraphSAGE')
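For readers who want to see what the heatmaps encode, here is a minimal plain-Python sketch of a confusion matrix, with rows for true labels and columns for predicted labels (the same orientation used by the axis labels above). The labels are made up for illustration, and the helper names are hypothetical.

```python
def confusion_counts(y_true, y_pred, num_classes):
    """matrix[t][p] counts the samples with true class t predicted as class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix

def accuracy_from_matrix(matrix):
    """Accuracy is the diagonal (correct predictions) divided by the total."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total

# Hypothetical labels for six nodes and three classes
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

cm = confusion_counts(y_true, y_pred, 3)
for row in cm:
    print(row)
print(accuracy_from_matrix(cm))
```

Off-diagonal cells show which classes a model confuses with each other, something a single accuracy number cannot reveal.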
Inference
Test Accuracy: GCN: 0.9397, GAT: 0.9338, SAGE: 0.9409
Conclusion
AI-enhanced graph analytics is a powerful approach for extracting insights from complex graph-structured data. With GNNs and more advanced variants such as GAT, data scientists can tackle a wide range of problems. Combining graph theory with AI opens new possibilities for innovation and discovery across many domains, so keep exploring and experimenting with these models to unlock their full potential.
Please feel free to share your thoughts in the comment section. Your suggestions are always welcome.