Walkthrough of Graph Attention Network (GAT) with Visual Implementation
Author(s): David R. Winer
Originally published on Towards AI.
Understanding Graph Neural Networks (GNNs) is increasingly relevant as transformers continue to tackle graph problems, such as those from the Open Graph Benchmark. Even if natural language turns out to be all a graph needs, GNNs remain a fruitful source of inspiration for future methods.
In this post, I'm going to walk through the implementation of a vanilla GNN layer, and then show the modifications needed for a vanilla graph attention layer as described in the ICLR paper Graph Attention Networks.
To start, imagine we have a graph of five text documents represented as a directed acyclic graph (DAG), encoded as an adjacency matrix. Document 0 has edges to documents 1, 2, and 3, so there are 1s in the 0th row for those columns.
To do the visualized implementation, I'm going to use Cerbrec Graphbook, a visual AI modeling tool. See my other post for more information about how to understand the visual representations in Graphbook.
We also have node features for each document. I pass the five documents into BERT as a single [5] 1D array of texts, producing a [5, 768] shape of embeddings from the pooler output.
For instructional purposes, I'll take just the first 8 dimensions of the BERT output as the node features so that we can follow the data shapes more easily. Now we have our adjacency matrix and our node features.
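If you want to follow along outside Graphbook, here is a minimal PyTorch sketch of the same setup. The document texts and all edges except those described in this walkthrough (row 0, plus the edge from node 3 to node 4 mentioned later) are placeholder assumptions, and it uses the Hugging Face bert-base-uncased checkpoint.

```python
import torch
from transformers import BertModel, BertTokenizer

docs = ["document 0 text", "document 1 text", "document 2 text",
        "document 3 text", "document 4 text"]  # placeholder documents

# Adjacency matrix for the DAG: row i has a 1 in column j when document i
# links to document j. Row 0 matches the walkthrough; other rows are placeholders.
adj = torch.tensor([
    [0., 1., 1., 1., 0.],
    [0., 0., 1., 0., 0.],
    [0., 0., 0., 1., 1.],
    [0., 0., 0., 0., 1.],
    [0., 0., 0., 0., 0.],
])

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")
inputs = tokenizer(docs, padding=True, return_tensors="pt")
with torch.no_grad():
    pooler_output = bert(**inputs).pooler_output  # shape [5, 768]

node_feats = pooler_output[:, :8]                 # first 8 dimensions -> [5, 8]
```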
GNN Layer
The general formula for a GNN layer is: for each node, take the features of all of its neighbors, multiply them by a weight matrix, sum them, and pass the result through an activation function. I've created a blank block with this formula as its title, passed in the adjacency matrix and the node features, and I will implement the formula inside the block.
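Written out (with a shared weight matrix W, node i's neighborhood N(i), and a nonlinearity σ), the layer computes:

```latex
h_i^{(l+1)} = \sigma\Big( \sum_{j \in \mathcal{N}(i)} W\, h_j^{(l)} \Big)
```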
When we implement this formula, we don't want to actually run a loop. If we can fully vectorize it, then training and inference on GPUs will be much faster because the multiplication becomes a single computing step. So instead, we tile (i.e., broadcast) the node features into a 3D shape: we had a [5, 8] shape of node features, and now we'll have a [5, 5, 8] shape where each cell in the 0th dimension is a repeat of the node features. We can think of the middle dimension as indexing "neighbor" features now: each node has a set of 5 possible neighbors.
We cannot directly broadcast the node features from a [5, 8] to a [5, 5, 8] shape. Instead, we first broadcast to [25, 8], because when broadcasting, every dimension of the target shape has to be greater than or equal to the corresponding original dimension. So we get the 5 and 8 parts of the shape (get_sub_arrays), multiply the first to get 25, and then concatenate them all together. Finally, we reshape the resulting [25, 8] back to [5, 5, 8], and we can verify right in Graphbook that each set of node features in the final 2 dimensions is identical.
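Outside Graphbook, the same tiling is a single broadcast. A minimal PyTorch sketch, continuing from the setup above:

```python
num_nodes, num_feats = node_feats.shape  # 5, 8

# Repeat the node features along a new leading dimension so that
# neighbor_feats[i, j, :] == node_feats[j, :] for every i.
neighbor_feats = node_feats.unsqueeze(0).expand(num_nodes, num_nodes, num_feats)
# shape [5, 5, 8]
```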
Next, we also want to broadcast the adjacency matrix to the same shape. This means that for every 1 in the adjacency matrix at row i and column j, there is a row of num_feats 1.0s at [i, j, :]. In our adjacency matrix, row 0 has a 1 in the 1st, 2nd, and 3rd columns, so the 0th cell has rows of num_feats 1.0s at positions 1, 2, and 3 (i.e., [0, 1:4, :]).
The implementation here is quite simple: parse the adjacency matrix to decimal and broadcast it from a [5, 5] shape to [5, 5, 8]. Now we can element-wise multiply this adjacency mask by our tiled node neighbor features.
We also want to add a self-loop to our adjacency matrix, so that when we sum over the neighbor features, we also include that node's own features.
After the element-wise multiply (with the self-loop included), we get the neighbor features for each node and zeros for the nodes that aren't connected by an edge (i.e., are not neighbors). For the 0th node, that includes the features of nodes 0 through 3. For the 3rd node, that includes the 3rd and 4th nodes.
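Here is how the self-loop and masking steps look as a PyTorch sketch, continuing from the tensors above:

```python
# Add self-loops so each node also aggregates its own features.
adj_self = adj + torch.eye(num_nodes)

# Broadcast the [5, 5] adjacency to [5, 5, 8] and zero out non-neighbor rows.
adj_mask = adj_self.unsqueeze(-1).expand(num_nodes, num_nodes, num_feats)
masked_neighbors = adj_mask * neighbor_feats  # shape [5, 5, 8]
```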
Next, we reshape to [25, 8] so that every neighbor feature is its own row, and pass that through a parameterized linear layer with your desired hidden size. Here I chose 32 and saved it as a global constant so it can be reused. The output of the linear layer will be [25, hidden_size]. Simply reshape that output to [5, 5, hidden_size], and now we're finally ready for the sum part of the formula!
We sum over the middle dimension (dimension index 1) so that we are summing the neighbor features for each node. The result is a [5, hidden_size] set of node embeddings that have gone through one layer. Chain these layers together and you have a GNN network; follow the guides at https://www.youtube.com/@Graphbook for how to train it.
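Putting the pieces together, here is a minimal sketch of the whole layer in PyTorch, with hidden_size = 32 as above and ReLU assumed as the activation:

```python
import torch.nn as nn

hidden_size = 32
linear = nn.Linear(num_feats, hidden_size)

# Project every (node, neighbor) feature row, then sum over the neighbor dimension.
flat = masked_neighbors.reshape(num_nodes * num_nodes, num_feats)    # [25, 8]
projected = linear(flat).reshape(num_nodes, num_nodes, hidden_size)  # [5, 5, 32]
node_embeddings = torch.relu(projected.sum(dim=1))                   # [5, 32]
```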
Graph Attention Layer
From the paper, the secret sauce behind the graph attention layer is the attention coefficient, shown in the formula below. In essence, for each edge we concatenate the node embeddings of its two endpoints and run the result through another linear layer, before applying a softmax.
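For reference, the attention coefficient from the Graph Attention Networks paper for an edge from node i to neighbor j is:

```latex
\alpha_{ij} = \frac{\exp\!\big(\mathrm{LeakyReLU}\big(\mathbf{a}^{\top}[\,W h_i \,\Vert\, W h_j\,]\big)\big)}
                   {\sum_{k \in \mathcal{N}_i} \exp\!\big(\mathrm{LeakyReLU}\big(\mathbf{a}^{\top}[\,W h_i \,\Vert\, W h_k\,]\big)\big)}
```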
These attention coefficients are then used to compute a linear combination of the original node features.
What we need to do is tile each node's own features for each of its neighbors, and then concatenate that with the node's neighbor features.
The secret sauce is to get each node's own features tiled for each neighbor. To do that, we swap the 0th and 1st dimensions of the tiled node features prior to masking.
The result is still a [5, 5, 8] shaped array, but now every row in [i, :, :] is the same and corresponds to node i's features. Now we can use the element-wise multiply to keep a node's repeated features only where it has a neighbor. Finally, we concatenate that with the neighbor features we created for the GNN layer to produce the concatenated features.
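In PyTorch terms, the dimension swap and concatenation look roughly like this, continuing from the GNN sketch above:

```python
# Swap dims 0 and 1 so self_feats[i, j, :] == node_feats[i, :] for every j.
self_feats = neighbor_feats.transpose(0, 1)

# Keep a node's own features only where an edge (or self-loop) exists, then
# concatenate with the masked neighbor features along the feature axis.
masked_self = adj_mask * self_feats                                 # [5, 5, 8]
concat_feats = torch.cat([masked_self, masked_neighbors], dim=-1)   # [5, 5, 16]
```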
We're almost there! Now that we have the concatenated features, we can put them through a linear layer. We then need to reshape back to [5, 5, hidden_size] so that we can softmax over the middle dimension and produce our attention coefficients.
We now have attention coefficients with shape [5, 5, hidden_size], essentially one embedding per edge of our n-node graph. The paper says these should be transposed (dimensions swapped), so I've gone back and done that prior to the ReLU, and now I softmax over the last dimension so that the coefficients are normalized over neighbors for each index along the hidden_size dimension. We're going to multiply these coefficients by the original node embeddings. Recall that the original node embeddings were shaped [5, 5, 8], where 8 came arbitrarily from slicing the first 8 features off BERT's encodings of our text documents.
Multiplying the [5, hidden_size, 5] shape by the [5, 5, 8] shape produces a [5, hidden_size, 8] shape. Then we sum over the hidden_size dimension to finally output [5, 8], matching our input shape. We could also put this through a non-linearity like another ReLU, and then chain this layer multiple times.
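A sketch of these final steps, following the shapes in the walkthrough; multiplying against the masked neighbor features is my reading of "the original node embeddings":

```python
# Score each (node, neighbor) pair with a linear layer over the concatenated features.
attn_linear = nn.Linear(2 * num_feats, hidden_size)
scores = attn_linear(concat_feats.reshape(num_nodes * num_nodes, 2 * num_feats))
scores = scores.reshape(num_nodes, num_nodes, hidden_size)            # [5, 5, 32]

# Transpose, ReLU, then softmax over the last (neighbor) dimension.
coeffs = torch.softmax(torch.relu(scores.transpose(1, 2)), dim=-1)    # [5, 32, 5]

# Combine with the [5, 5, 8] neighbor features and sum out the hidden dimension.
out = torch.bmm(coeffs, masked_neighbors).sum(dim=1)                  # [5, 8]
```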
Conclusion
Thus far, we've gone over a visual implementation of single GNN and GAT layers. You can find the project in this GitHub repo. In the paper, they also explain how they extend the method to multi-head attention. Let me know in the comments if you'd like me to cover that part as well, or if there's anything else you'd like me to cover using Graphbook.