
XAI: Graph Neural Networks
Last Updated on September 12, 2025 by Editorial Team
Author(s): Kalpan Dharamshi
Originally published on Towards AI.
What are Graph Neural Networks?
Graph Neural Networks (GNNs) combine the representational power of neural networks with the relational structure of graphs. Deep neural networks, particularly those built on multi-head attention, excel at processing images and text. Graphs, on the other hand, represent data as a network of nodes (entities) and edges (relationships between entities), which captures intricate dependencies. A GNN uses this structure to iteratively aggregate information for each node from its neighbors through a message-passing mechanism. This lets the network learn node embeddings that encode both a node’s features and its position within the graph topology. These embeddings help a GNN uncover hidden patterns, make predictions, and perform node and graph classification.
How does a GNN work behind the scenes?
This section walks through the underlying math operations.

For simplicity, consider a linear model, y = wx + b, where y is the output, x is the input, w is the weight, and b is the bias. Traditional neural networks apply this kind of linear transformation to their inputs. For higher-dimensional data, the variables become vectors and matrices, and scalar multiplication and addition become matrix multiplication and matrix addition.
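In vector form, the linear model with output y, input x, weights w, and bias b can be sketched as follows. The notation is an assumption made for illustration: the input is a vector with d features, W is a weight matrix, and the bias and output are vectors of length k; the symbols d and k do not come from the article.

```latex
\mathbf{y} = W\mathbf{x} + \mathbf{b},
\qquad \mathbf{x} \in \mathbb{R}^{d},\quad W \in \mathbb{R}^{k \times d},\quad \mathbf{b},\,\mathbf{y} \in \mathbb{R}^{k}
```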
In a graph, the added complexity comes from a node’s connected neighbors. A message-passing mechanism aggregates the information from those neighbors, so the linear transformation becomes a summation over the n neighbors of a node.
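One way to write this neighbor summation is sketched below. The notation is assumed rather than taken from the article: x_j is the feature vector of node j, N(i) is the set of n neighbors of node i, h_i is the aggregated representation of node i, and W and b are weights and a bias shared by all nodes. Implementations often add node i itself to N(i) through a self-loop, which is a design choice.

```latex
\mathbf{h}_i = \sum_{j \in \mathcal{N}(i)} W\mathbf{x}_j + \mathbf{b},
\qquad |\mathcal{N}(i)| = n
```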

Without normalization, the aggregated value for a node with many neighbors will be larger than that for a node with few neighbors, simply because more terms are summed. The aggregation therefore needs to be normalized, and the degrees of the node and of each of its neighbors act as the normalization factor.
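The effect of degree normalization can be seen on a toy graph. The sketch below is not the article's code: it assumes the symmetric scaling used by graph convolutional networks, where each message from neighbor j to node i is divided by sqrt(deg(i) * deg(j)), and it assumes every node also receives its own feature through a self-loop. The graph and feature values are hypothetical.

```python
import math

# Hypothetical toy graph: node 0 is a hub, nodes 1-3 are leaves.
neighbors = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
features = {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0}  # one scalar feature per node


def with_self_loop(node):
    """Neighbors of a node plus the node itself (assumed self-loop)."""
    return neighbors[node] + [node]


def degree(node):
    """Degree counted after adding the self-loop."""
    return len(with_self_loop(node))


def sum_aggregate(node):
    """Plain summation: grows with the number of neighbors."""
    return sum(features[j] for j in with_self_loop(node))


def normalized_aggregate(node):
    """Each message is scaled by 1 / sqrt(deg(i) * deg(j))."""
    return sum(
        features[j] / math.sqrt(degree(node) * degree(j))
        for j in with_self_loop(node)
    )


raw = {n: sum_aggregate(n) for n in neighbors}
norm = {n: normalized_aggregate(n) for n in neighbors}
print("sum:", raw)
print("normalized:", norm)
```

With plain summation the hub's value is twice that of a leaf; after normalization the gap between the hub and the leaves shrinks.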

Data Analysis
We have chosen Zachary’s Karate Club dataset, available in the torch_geometric (PyTorch Geometric) package, for analysis and explanation. This classic dataset is widely used as a benchmark for graph analysis and serves as an excellent starting point for anyone learning and experimenting with GNNs.
We install the torch_geometric and pandas packages and import the necessary classes and functions for our experiment.
!pip install torch_geometric
!pip install pandas
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.datasets import KarateClub
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.nn import GCNConv, GCN2Conv, GATConv
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
dataset = KarateClub()
data = dataset[0]
G = to_networkx(data, to_undirected=True)
plt.figure(figsize=(12,12))
plt.axis('off')
nx.draw_networkx(G,
                 pos=nx.spring_layout(G, seed=0),
                 with_labels=True,
                 node_size=800,
                 node_color=data.y,
                 cmap="hsv",
                 vmin=-2,
                 vmax=3,
                 width=0.8,
                 edge_color="grey",
                 font_size=14
                 )
plt.show()

The dataset represents a social network of 34 members of a karate club. Each node represents a member, and each edge represents a social interaction between two members outside the club. These interactions divide the members into groups within the club. The task is node classification: predicting the group to which each member belongs.
Graph Characteristics
- Number of Nodes (Members): 34
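- Number of Edges (Social Interactions): 156 directed entries, that is, 78 undirected, unweighted edges each stored in both directions
- Node Features: Each node has 34 features (a one-hot encoding of the node's index).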
- Node Labels: There are four classes available for classification.
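The edge count deserves a note: PyTorch Geometric keeps edges in a two-row `edge_index` structure (sources in one row, targets in the other) and represents an undirected edge as two directed entries. The sketch below imitates that layout with plain Python lists on a hypothetical three-edge graph; it does not load the Karate Club data.

```python
# Hypothetical toy graph with three undirected edges.
undirected_edges = [(0, 1), (0, 2), (1, 2)]

# Store each undirected edge twice, once per direction,
# mirroring the two-row layout of an edge_index structure.
sources, targets = [], []
for u, v in undirected_edges:
    sources += [u, v]
    targets += [v, u]
edge_index = [sources, targets]

num_directed_entries = len(edge_index[0])
num_undirected_edges = num_directed_entries // 2
print(num_directed_entries, num_undirected_edges)
```

By the same arithmetic, the 156 entries reported for the Karate Club graph correspond to 78 undirected edges.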
GNN Implementation
We use a Graph Attention Network (GAT) as our GNN. Its attention mechanism learns a weight for each neighboring node, so the most relevant neighbors contribute the most to a node's prediction.
class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Two attention-based graph convolution layers
        self.gat_conv_first = GATConv(in_channels, hidden_channels)
        self.gat_conv_second = GATConv(hidden_channels, out_channels)
    def forward(self, x, edge_index):
        x = F.relu(self.gat_conv_first(x, edge_index))
        # Returns raw, unnormalized class scores (logits)
        return self.gat_conv_second(x, edge_index)
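To make the attention step more concrete, here is a minimal sketch of how a single attention head could weight a node's neighbors. It is a simplification, not the internals of `GATConv`: features are scalars, `weight` stands in for the learned linear transformation, `attn` stands in for the learned attention vector, and all numeric values are hypothetical. The scoring follows the usual GAT recipe of a LeakyReLU over a linear score, followed by a softmax across neighbors.

```python
import math


def leaky_relu(value, negative_slope=0.2):
    return value if value > 0 else negative_slope * value


def attention_weights(center, neighbor_feats, weight, attn):
    """Softmax-normalized attention of one node over its neighbors."""
    a_center, a_neighbor = attn
    scores = [
        leaky_relu(a_center * weight * center + a_neighbor * weight * h)
        for h in neighbor_feats
    ]
    # Subtract the maximum score for a numerically stable softmax.
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical scalar features and parameters.
center_feature = 1.0
neighbor_features = [0.5, 2.0, -1.0]
weight = 1.5
alphas = attention_weights(center_feature, neighbor_features, weight, attn=(0.3, 0.8))

# The updated representation is the attention-weighted sum of transformed neighbors.
updated = sum(alpha * weight * h for alpha, h in zip(alphas, neighbor_features))
print(alphas, updated)
```

The weights sum to one, and the neighbor with the highest score receives the largest share of the aggregated message.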
Model Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GAT(in_channels=dataset.num_node_features, hidden_channels=5, out_channels=dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
loss_fn = torch.nn.CrossEntropyLoss()
Model Training and Results
def accuracy(pred_y, y):
    # Fraction of nodes whose predicted class matches the label
    return (pred_y == y).sum() / len(y)
# Training loop
def train():
    model.train()
    optimizer.zero_grad()
    z = model(data.x, data.edge_index)
    # The loss uses only the labeled training nodes
    loss = loss_fn(z[data.train_mask], data.y[data.train_mask])
    # The accuracy is measured over all nodes
    acc = accuracy(z.argmax(dim=1), data.y)
    loss.backward()
    optimizer.step()
    return loss, acc
# Train the model
loss_list = []
acc_list = []
for epoch in range(1, 101):
    loss, acc = train()
    loss_list.append(loss.item())
    acc_list.append(acc.item())
    print(f'Epoch : {epoch} Loss : {loss}, Accuracy: {acc}')
df = pd.DataFrame({'Epochs': range(1,101), 'Loss': loss_list, 'Accuracy': acc_list})
df.plot(x='Epochs', y=['Loss', 'Accuracy'])
plt.title('Line Plot of Loss and Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Loss and Accuracy')
plt.show()

The training results show a gradual decline in loss and an increase in accuracy. The accuracy, which the code above measures over all nodes rather than only the training nodes, exceeds 90%, which is sufficient for our experimentation purposes.
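Because the loss is computed on the nodes selected by `data.train_mask` while the accuracy covers every node, it can help to report the two separately. The sketch below shows the idea with plain Python lists; the labels, predictions, and mask are hypothetical, and the `mask` parameter is an addition that the article's `accuracy` function does not have.

```python
def accuracy(predictions, labels, mask=None):
    """Share of correct predictions among the nodes selected by mask."""
    if mask is None:
        mask = [True] * len(labels)
    selected = [(p, y) for p, y, m in zip(predictions, labels, mask) if m]
    correct = sum(1 for p, y in selected if p == y)
    return correct / len(selected)


# Hypothetical labels, predictions, and training mask for six nodes.
labels = [0, 1, 2, 3, 0, 1]
predictions = [0, 1, 2, 3, 1, 1]
train_mask = [True, True, True, True, False, False]

train_acc = accuracy(predictions, labels, train_mask)  # labeled nodes only
overall_acc = accuracy(predictions, labels)  # every node
held_out_acc = accuracy(predictions, labels, [not m for m in train_mask])
print(train_acc, overall_acc, held_out_acc)
```

A perfect score on the labeled nodes can coexist with a lower score on the remaining nodes, which is why the distinction matters when reading the accuracy curve.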
XAI
Finally, we reach a point where we can start the explanation process of our GNN.
The torch_geometric package provides GNNExplainer, which learns masks over edges and node features to show which neighbors and which features matter most for a node's classification.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        task_level='node',
        # The GAT model above returns unnormalized logits, so the return type is 'raw'
        return_type='raw',
        mode='multiclass_classification'
    )
)
# Explain the prediction for a single node (e.g., index 5)
explanation = explainer(x=data.x, edge_index=data.edge_index, index=5)
# Visualize the explanation
explanation.visualize_graph(path='explanation.png')
#explanation.get_explanation_subgraph().visualize_graph(path='sub_graph.png')
explanation.visualize_feature_importance(path='feature_importance.png',top_k=10)

Looking closely, some edges in the explanation graph are faint and others are prominent. Prominent edges carry a high weight, meaning a strong influence on the prediction for the node under consideration; faint edges carry a low weight.
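The same information can be read numerically. In PyTorch Geometric the explanation object exposes the learned edge weights as `explanation.edge_mask`, with one value per edge entry. The sketch below ranks edges by such a mask using plain Python lists; the edges and mask values are hypothetical and are not the output of the run above, and `top_edges` is an illustrative helper.

```python
# Hypothetical directed edges (source, target) and mask values in [0, 1].
edges = [(0, 5), (6, 5), (10, 5), (16, 5), (4, 10)]
edge_mask = [0.91, 0.12, 0.78, 0.05, 0.33]


def top_edges(edges, mask, k):
    """Return the k edges with the largest mask values, strongest first."""
    ranked = sorted(zip(edges, mask), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


strongest = top_edges(edges, edge_mask, k=2)
for (source, target), weight in strongest:
    print(f"edge {source} -> {target}: weight {weight:.2f}")
```

Edges near the top of this ranking are the ones drawn prominently in the explanation graph.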

The feature-importance chart highlights the features that matter most for the node classification by showing the weight of each input attribute.
The feature labels (e.g., ‘5’, ‘10’, ‘0’) on the y-axis correspond to the indices of the input features, not the node IDs in the graph. As the bar chart shows, features 5, 10, and 0 have the highest importance scores for this node classification. Because the Karate Club features are a one-hot encoding of the node index, feature k is nonzero only for node k in this particular dataset.
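A per-feature score like the one in the bar chart can be derived from the node mask, which holds one value per node and feature when `node_mask_type='attributes'` is used. The sketch below sums a hypothetical mask over nodes and keeps the top entries; treating the chart as a column-wise sum is an assumption about how the scores are aggregated, and the numbers are invented for illustration.

```python
# Hypothetical node mask: one row per node, one column per feature.
node_mask = [
    [0.1, 0.0, 0.7],
    [0.3, 0.2, 0.1],
    [0.0, 0.1, 0.9],
]


def feature_importance(mask, top_k):
    """Sum the mask over nodes and return the top_k (feature index, score) pairs."""
    num_features = len(mask[0])
    totals = [sum(row[f] for row in mask) for f in range(num_features)]
    ranked = sorted(range(num_features), key=lambda f: totals[f], reverse=True)
    return [(f, totals[f]) for f in ranked[:top_k]]


top_features = feature_importance(node_mask, top_k=2)
for index, score in top_features:
    print(f"feature {index}: {score:.2f}")
```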
XAI libraries such as SHAP can also generate similar feature importance metrics for machine learning models.
Why do we need XAI?
Imagine we are working on a massive network and must explain a GNN's node classification. Without XAI capabilities, it would be hard to determine which neighboring nodes and features influenced the decision. GNN explainers bridge this gap and provide the required insight.
XAI capabilities also help troubleshoot and debug the model, build trust and interpretability, and meet audit and compliance requirements.
In essence, XAI serves as the bridge between a GNN’s black-box output and a human’s need for understanding, ensuring these powerful models can be used responsibly and effectively.
The entire code for the experiment can be found on GitHub.
Hope you liked the article and learned something new today!
Note: Content contains the views of the contributing authors and not Towards AI.