XAI: Graph Neural Networks

Last Updated on September 12, 2025 by Editorial Team

Author(s): Kalpan Dharamshi

Originally published on Towards AI.

What are Graph Neural Networks?

Graph Neural Networks (GNNs) combine the representational power of neural networks with the complex structure of graphs. Deep neural networks, particularly those built on multi-head attention, excel at processing images and text. Graphs, on the other hand, represent data as a network of nodes (entities) and edges (relationships between entities), which lets them capture intricate dependencies. GNNs exploit this structure by iteratively aggregating information for each node from its neighbors through a message-passing mechanism. This allows the network to learn node embeddings that encode both a node’s own features and its position within the graph topology. These embeddings help GNNs uncover hidden patterns, make predictions, and perform tasks such as node and graph classification.
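To make the iterative aggregation concrete, here is a minimal, framework-free sketch (not the torch_geometric implementation). On a toy path graph, each round adds every node's neighbor values to its own value, so after k rounds a node's value depends on nodes up to k hops away. The graph, the plain unweighted sum, and the helper name `message_passing_round` are illustrative assumptions.

```python
# Minimal sketch of iterative message passing on a toy path graph 0 - 1 - 2 - 3.
# Assumption: each round, a node's new value is its own value plus the sum of
# its neighbors' values (no weights, no non-linearity).


def message_passing_round(features, neighbors):
    """Return new features where each node sums itself and its neighbors."""
    new_features = {}
    for node, value in features.items():
        new_features[node] = value + sum(features[n] for n in neighbors[node])
    return new_features


# Undirected path graph stored as adjacency lists.
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}

# Only node 3 starts with a non-zero signal.
features = {0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0}

history = [features]
for _ in range(3):
    features = message_passing_round(features, neighbors)
    history.append(features)

for k, feats in enumerate(history):
    reached = sorted(n for n, v in feats.items() if v != 0)
    print(f"after {k} round(s): signal has reached nodes {reached}")
```

After one round the signal from node 3 reaches node 2, after two rounds node 1, and after three rounds node 0, which is why stacking message-passing layers widens each node's receptive field.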

How does a GNN work behind the scenes?

In this section, we walk through the underlying math operations.

$$y = wx + b$$
Linear equation for a neural network

For simplicity, we consider a linear equation for the neural network model. In the equation, y is the output, x is the input, w represents the weights of the neural network, and b is its bias. Each layer of a traditional neural network applies such a linear transformation to its inputs, typically followed by a non-linear activation function. For higher-dimensional data, the transformation is expressed with matrices: the variables become vectors or matrices, and multiplication and addition become matrix multiplication and matrix addition.
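As a small illustration of the matrix form, the sketch below computes Y = XW + b for a batch of two samples using plain Python lists. The shapes and values are made up; the helper name `linear` is hypothetical. Note that PyTorch's `torch.nn.Linear` stores its weight as (out_features, in_features) and computes x Wᵀ + b, so its weight layout is the transpose of the one used here.

```python
# Minimal sketch of the linear transformation Y = XW + b with nested lists.
# Assumption: rows of X are samples, W has shape (in_features, out_features),
# and b has one entry per output feature.


def linear(X, W, b):
    """Compute X @ W + b for a batch of row vectors."""
    out = []
    for row in X:
        out_row = []
        for j in range(len(b)):
            # Dot product of the input row with column j of W, plus bias j.
            value = sum(row[k] * W[k][j] for k in range(len(row))) + b[j]
            out_row.append(value)
        out.append(out_row)
    return out


X = [[1.0, 2.0],          # two samples, two input features each
     [0.0, 1.0]]
W = [[0.5, -1.0, 2.0],    # maps 2 input features to 3 outputs
     [1.0, 0.0, 1.0]]
b = [0.1, 0.2, 0.3]

Y = linear(X, W, b)
print(Y)
```

Each output row is the corresponding input row multiplied by W with the bias added, which is exactly y = wx + b applied to every sample at once.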

For graphs, the added complexity comes from a node’s connected neighbors. A message-passing mechanism aggregates the information from all of a node’s neighbors, which turns the equation into a summation over the n neighbors of the node.

Aggregation of data from all n neighboring nodes.
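Written with the symbols of the linear equation above, one common form of this aggregation for node i is sketched below, where $\mathcal{N}(i)$ is the set of its n neighbors and the bias is dropped for brevity. Many GNN layers, such as GCN, also include node i itself through a self-loop; that variant is not shown.

```latex
y_i = \sum_{j \in \mathcal{N}(i)} w \, x_j, \qquad |\mathcal{N}(i)| = n
```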

The aggregated data must be normalized to account for differences in the number of neighbors: a plain sum is larger for nodes with many neighbors than for nodes with few, regardless of the actual feature values. Therefore, the degrees of the node and of each of its neighbors are used as a normalization factor.

Normalized representation for each node.
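A minimal sketch of this normalization, assuming the symmetric form used by GCN-style layers, y_i = Σ_{j ∈ N(i)} w·x_j / √(deg(i)·deg(j)), with scalar features, a scalar weight w, and no self-loops. The toy star graphs and the helper names `star_graph` and `aggregate` are illustrative.

```python
import math


def star_graph(num_leaves):
    """Adjacency lists for a star: node 0 is the hub, nodes 1..num_leaves are leaves."""
    neighbors = {0: list(range(1, num_leaves + 1))}
    for leaf in range(1, num_leaves + 1):
        neighbors[leaf] = [0]
    return neighbors


def aggregate(neighbors, x, w, i, normalize):
    """Sum w * x_j over neighbors j of node i, optionally scaled by 1/sqrt(deg(i) * deg(j))."""
    total = 0.0
    for j in neighbors[i]:
        if normalize:
            scale = 1.0 / math.sqrt(len(neighbors[i]) * len(neighbors[j]))
        else:
            scale = 1.0
        total += scale * w * x[j]
    return total


w = 1.0  # scalar weight shared by all messages (toy assumption)
for num_leaves in (4, 16):
    g = star_graph(num_leaves)
    x = {node: 1.0 for node in g}  # every node carries the same scalar feature
    raw = aggregate(g, x, w, 0, normalize=False)
    norm = aggregate(g, x, w, 0, normalize=True)
    print(f"hub with {num_leaves} neighbors: sum={raw}, normalized={norm}")
```

Without normalization, the hub's aggregate grows linearly with its degree (4.0 for 4 neighbors, 16.0 for 16); with normalization it grows only with the square root (2.0 and 4.0), so high-degree nodes are far less dominated by their neighbor count. torch_geometric's `GCNConv` applies this kind of normalization and also adds self-loops by default, which the sketch omits.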

Data Analysis

We use Zachary’s Karate Club dataset, available in the torch_geometric (PyTorch Geometric) package, for analysis and explanation. This classic dataset is widely used as a benchmark for graph analysis and is an excellent starting point for anyone learning and experimenting with GNNs.

We install the torch_geometric and pandas packages and import the classes and functions needed for our experiments.

```python
# Install dependencies (notebook syntax, e.g., Google Colab)
!pip install torch_geometric
!pip install pandas

import torch
import torch.nn.functional as F
from torch_geometric.datasets import KarateClub
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.nn import GATConv
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd

# Load Zachary's Karate Club; the dataset contains a single graph
dataset = KarateClub()
data = dataset[0]
```

```python
# Convert to a NetworkX graph and color each node by its class label
G = to_networkx(data, to_undirected=True)
plt.figure(figsize=(12, 12))
plt.axis('off')
nx.draw_networkx(
    G,
    pos=nx.spring_layout(G, seed=0),
    with_labels=True,
    node_size=800,
    node_color=data.y,
    cmap="hsv",
    vmin=-2,
    vmax=3,
    width=0.8,
    edge_color="grey",
    font_size=14,
)
plt.show()
```
Node classification for Karate Club Dataset

The dataset represents a social network of 34 members of a karate club. Each node represents a member, and each edge represents a social interaction between two members outside the club. These interactions split the members into different groups within the club. The core task is to predict the group each member belongs to, i.e., to classify each node.

Graph Characteristics

  • Number of Nodes (Members): 34
  • Number of Edges (Social Interactions): 78 undirected, unweighted edges, stored in both directions as 156 entries in edge_index
  • Node Features: Each node has 34 features (a one-hot identity encoding, so feature k is non-zero only for node k).
  • Node Labels: There are four classes available for classification.
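The edge count follows from how torch_geometric stores graphs: edge_index is a 2 x E tensor in coordinate (COO) format, and an undirected edge is stored once in each direction. A minimal, dependency-free sketch of that convention, using a made-up three-edge graph and the hypothetical helper name `to_coo_both_directions`:

```python
def to_coo_both_directions(undirected_edges):
    """Return [sources, targets] with every undirected edge listed in both directions."""
    sources, targets = [], []
    for u, v in undirected_edges:
        sources += [u, v]
        targets += [v, u]
    return [sources, targets]


edges = [(0, 1), (0, 2), (1, 2)]  # a triangle: 3 undirected edges
edge_index = to_coo_both_directions(edges)
print(edge_index)          # [[0, 1, 0, 2, 1, 2], [1, 0, 2, 0, 2, 1]]
print(len(edge_index[0]))  # 6 entries for 3 undirected edges
```

Applied to the 78 undirected edges of Zachary's network, the same convention gives 156 entries in data.edge_index.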

GNN Implementation

We use a Graph Attention Network (GAT) as our GNN. Its attention mechanism learns how much weight to give each neighboring node’s message when aggregating information for a prediction.
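For intuition, a single GAT attention head scores each neighbor j of node i as e_ij = LeakyReLU(aᵀ[W h_i ‖ W h_j]) and normalizes the scores with a softmax over the neighborhood, as in the original GAT formulation. The framework-free sketch below computes these coefficients for one toy node. The features, the hand-picked attention vector `a`, and treating W as the identity are illustrative assumptions; `GATConv` learns both W and a during training.

```python
import math

# Minimal single-head GAT attention sketch for one target node.
# Assumptions (illustrative, not learned): W is the identity, so h holds the
# transformed features directly; the attention vector `a` is fixed by hand;
# LeakyReLU uses a negative slope of 0.2 as in the GAT paper.


def leaky_relu(z, negative_slope=0.2):
    return z if z > 0 else negative_slope * z


def attention_coefficients(h, a, target, neighbors):
    """Softmax over e_ij = LeakyReLU(a . [h_i || h_j]) for j in neighbors."""
    scores = {}
    for j in neighbors:
        concat = h[target] + h[j]  # list concatenation [h_i || h_j]
        scores[j] = leaky_relu(sum(ak * ck for ak, ck in zip(a, concat)))
    # Softmax, shifted by the max score for numerical stability.
    m = max(scores.values())
    exp_scores = {j: math.exp(s - m) for j, s in scores.items()}
    total = sum(exp_scores.values())
    return {j: e / total for j, e in exp_scores.items()}


h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0], 3: [-1.0, 0.5]}
a = [0.5, 0.5, 1.0, -0.5]  # length 2 * feature_dim

# Node 0 attends over itself (self-loop) and its neighbors 1, 2, 3.
alpha = attention_coefficients(h, a, target=0, neighbors=[0, 1, 2, 3])
for j, weight in alpha.items():
    print(f"alpha_0{j} = {weight:.3f}")
```

The coefficients sum to 1 over the neighborhood, so each neighbor's message is scaled by its alpha before aggregation and higher-scoring neighbors dominate the new embedding.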

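```python
class GAT(torch.nn.Module):
    """Two-layer Graph Attention Network for node classification."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # First attention layer: input features -> hidden embeddings
        self.gat_conv_first = GATConv(in_channels, hidden_channels)
        # Second attention layer: hidden embeddings -> one score per class
        self.gat_conv_second = GATConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.gat_conv_first(x, edge_index))
        # Return raw logits; CrossEntropyLoss applies log-softmax internally
        return self.gat_conv_second(x, edge_index)
```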

Model Setup

```python
# Use a GPU if available; the graph is small enough to train on a CPU as well
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GAT(in_channels=dataset.num_node_features, hidden_channels=5, out_channels=dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
loss_fn = torch.nn.CrossEntropyLoss()
```

Model Training and Results

```python
def accuracy(pred_y, y):
    """Fraction of nodes whose predicted class matches the label."""
    return (pred_y == y).sum() / len(y)


# Training loop
def train():
    model.train()
    optimizer.zero_grad()
    z = model(data.x, data.edge_index)
    # The loss uses only the labeled nodes in the training mask
    loss = loss_fn(z[data.train_mask], data.y[data.train_mask])
    # Accuracy is measured over all nodes in the graph
    acc = accuracy(z.argmax(dim=1), data.y)
    loss.backward()
    optimizer.step()
    return loss, acc


# Train the model for 100 epochs and record the metrics
loss_list = []
acc_list = []
for epoch in range(1, 101):
    loss, acc = train()
    loss_list.append(loss.item())
    acc_list.append(acc.item())
    print(f'Epoch: {epoch} Loss: {loss.item():.4f}, Accuracy: {acc.item():.4f}')

# Plot loss and accuracy per epoch
df = pd.DataFrame({'Epochs': range(1, 101), 'Loss': loss_list, 'Accuracy': acc_list})
df.plot(x='Epochs', y=['Loss', 'Accuracy'])
plt.title('Line Plot of Loss and Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Loss and Accuracy')
plt.show()
```
The loss and accuracy curves

The training curves show the loss declining steadily while the accuracy increases. The accuracy, which is measured across all nodes in the graph while the loss uses only the labeled nodes in data.train_mask, exceeds 90%, which is sufficient for our experiments.
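The difference between what the loss sees and what the accuracy measures can be sketched without any framework. The logits, labels, and mask below are made up for five hypothetical nodes; the `cross_entropy` helper mirrors what `torch.nn.CrossEntropyLoss` computes for one row of raw logits (log-sum-exp minus the target logit), averaged over the masked nodes.

```python
import math

# Hypothetical logits for 5 nodes and 2 classes, true labels, and a train mask.
logits = [[2.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 3.0], [2.0, 1.0]]
labels = [0, 1, 1, 1, 1]
train_mask = [True, True, False, False, False]  # only nodes 0 and 1 are labeled for training


def cross_entropy(row, target):
    """Cross-entropy of one row of raw logits: log-sum-exp minus the target logit."""
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
    return log_sum_exp - row[target]


# Loss: averaged over the masked (training) nodes only.
train_ids = [i for i, keep in enumerate(train_mask) if keep]
loss = sum(cross_entropy(logits[i], labels[i]) for i in train_ids) / len(train_ids)

# Accuracy: argmax prediction compared with the label for every node.
preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
acc = sum(p == y for p, y in zip(preds, labels)) / len(labels)

print(f"masked loss = {loss:.4f}, accuracy over all nodes = {acc:.2f}")
```

Here both training nodes are classified correctly, yet accuracy over all nodes is only 0.6, which is why it matters which set of nodes a reported accuracy covers.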

XAI

With a trained model, we can now start explaining our GNN.

The torch_geometric package provides GNNExplainer, used through its Explainer class, to identify and visualize which edges (messages) influence the prediction for a node and which node features matter most for its classification.
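GNNExplainer learns a soft mask with one value in (0, 1) per edge (and, with node_mask_type='attributes', per node feature), optimized so that the masked graph still yields the model's prediction while keeping the mask small. The sketch below only illustrates how such an edge mask scales the messages arriving at a toy node; the neighbor features and mask logits are hand-picked assumptions, and it does not perform GNNExplainer's optimization.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Toy target node with three incoming edges and scalar neighbor features.
incoming = {1: 2.0, 2: -1.0, 3: 0.5}  # neighbor id -> feature x_j
w = 1.0                                # scalar weight (toy assumption)

# Hand-picked mask logits; the sigmoid maps them into (0, 1).
mask_logits = {1: 4.0, 2: -4.0, 3: 0.0}
edge_mask = {j: sigmoid(t) for j, t in mask_logits.items()}


def masked_aggregate(mask):
    """Scale each neighbor's message w * x_j by its edge-mask value, then sum."""
    return sum(mask[j] * w * x for j, x in incoming.items())


full = masked_aggregate({j: 1.0 for j in incoming})
soft = masked_aggregate(edge_mask)
print(f"all edges kept: {full:.3f}, with soft mask: {soft:.3f}")
print({j: round(m, 3) for j, m in edge_mask.items()})
```

An edge whose mask value is close to 1 passes its message almost unchanged, while one close to 0 is effectively removed; the learned mask values are what an explanation uses to rank edges by importance.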

```python
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',      # explain the model's own prediction
    node_mask_type='attributes',   # learn a mask value for every node feature
    edge_mask_type='object',       # learn one mask value per edge
    model_config=dict(
        task_level='node',
        return_type='raw',         # the GAT above returns raw logits
        mode='multiclass_classification',
    ),
)

# Explain the prediction for a single node (e.g., index 5)
explanation = explainer(x=data.x, edge_index=data.edge_index, index=5)

# Visualize the explanation graph and the 10 most important features
explanation.visualize_graph(path='explanation.png')
# Optional: visualize only the explanation subgraph
# explanation.get_explanation_subgraph().visualize_graph(path='sub_graph.png')
explanation.visualize_feature_importance(path='feature_importance.png', top_k=10)
```
Inspect node 5 and its neighbors

Looking closely, some edges in the explanation above are faint while others are prominent. Prominent edges carry high importance, meaning they strongly influence the prediction for the node under consideration, while faint edges carry low importance.

Top 10 features leveraged for node classification

The bar chart above highlights the most significant features for this node classification. Each feature’s importance score indicates how much that attribute contributes to the prediction.

The feature labels (e.g., ‘5’, ‘10’, ‘0’) on the y-axis correspond to the indices of the input features, not the node IDs in the graph. As the bar chart shows, features 5, 10, and 0 have the highest importance scores for this node classification.
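For intuition, the sketch below derives per-feature scores from an attribute-level node mask of shape [num_nodes, num_features], assuming (as we understand `visualize_feature_importance` to do) that the mask is summed over nodes and the `top_k` largest totals are kept. The mask values and the helper name `feature_importance` are made up for illustration.

```python
# Hypothetical attribute-level node mask: 3 nodes x 4 features, values in [0, 1].
node_mask = [
    [0.9, 0.1, 0.0, 0.3],
    [0.2, 0.0, 0.7, 0.1],
    [0.8, 0.1, 0.1, 0.0],
]


def feature_importance(mask, top_k):
    """Sum the mask over nodes for each feature, then return the top_k (index, score) pairs."""
    num_features = len(mask[0])
    scores = [sum(row[f] for row in mask) for f in range(num_features)]
    ranked = sorted(range(num_features), key=lambda f: scores[f], reverse=True)
    return [(f, scores[f]) for f in ranked[:top_k]]


print(feature_importance(node_mask, top_k=2))
```

The returned indices are feature positions, not node IDs, which matches how the bar chart labels its y-axis.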

XAI libraries such as SHAP can also generate similar feature-importance metrics for machine learning models.

Why do we need XAI?

Imagine we are working with a massive network and need to explain a GNN’s classification of a particular node. Without XAI capabilities, it would be difficult to determine which neighboring nodes and features influenced the decision. GNN explainers close this gap by providing the required insight.

XAI capabilities also help us troubleshoot and debug models, build trust through interpretability, and meet audit and compliance requirements.

In essence, XAI serves as the bridge between a GNN’s black-box output and a human’s need for understanding, ensuring these powerful models can be used responsibly and effectively.

The entire code for the experiment can be found on GitHub.

I hope you liked the article and learned something new today!
