Walkthrough of LoRA Fine-tuning on GPT and BERT with Visual Implementation
Last Updated on November 5, 2023 by Editorial Team
Author(s): David R. Winer
Originally published on Towards AI.
Fine-tuning, the process of updating the weights of a pre-trained transformer model on downstream data, can be the difference between a model that's not ready for production and one that is robust enough to put in front of customers.
Back when BERT and GPT2 were first revolutionizing natural language processing (NLP), there was really only one playbook for fine-tuning. You had to be very careful with fine-tuning because of catastrophic forgetting. In essence, after you pre-trained your model, you didn't want to overwrite the original weights so much that the model forgot previously learned connections. The practitioner's secret was to dial the learning rate very low, freeze all but the last couple of layers, and run through the downstream training data very carefully, with perhaps only one epoch for a large dataset. There are a few downsides to this approach. The unfrozen layers still carry their full parameter counts, so gradients and optimizer states stay large, and any layer you freeze is one your fine-tuning cannot affect at all.
Fast forward to today, and fine-tuning has a few new techniques, typically categorized together as Parameter Efficient Fine-tuning (PEFT) methods, with Low-Rank Adaptation of Large Language Models (LoRA) as the primary example.
The central idea of LoRA is that you should keep the original pre-trained weights and add some new low-parameter weights to fine-tune instead. For example, if you have a weight matrix with 768² = 589,824 parameters, then you pick some small integer r and add two new weight matrices of sizes 768 x r and r x 768. So if r = 4, that's 768 * 4 + 4 * 768 = 6,144 parameters, close to 1% of the original!
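As a quick sanity check on that arithmetic, here is the same calculation in plain Python:

```python
hidden_size = 768
r = 4

full_params = hidden_size * hidden_size          # 589,824 parameters in the original dense weight
lora_params = hidden_size * r + r * hidden_size  # 6,144 parameters in the A and B matrices

print(full_params, lora_params, lora_params / full_params)  # 589824 6144 ~0.0104
```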
These low-parameter weights are added to your pre-trained weights as part of the compute graph. When training, you only update the new weights, so the backward pass only produces gradients for the new weights and the optimizer only tracks optimizer states for the new weights. As a result, there is less compute during training and less memory needed for gradients and optimizer states. Since today's models have so many parameters, LoRA is an important technique for getting fine-tuning to run on "regular"-sized machines, and it speeds up training by requiring less compute overall. The small downside is that the extra weights add to the overall memory needed at inference time, though only by a small percentage, and the LoRA authors note that the learned product can be merged back into the original weights to avoid even that.
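As a rough sketch of what that means in practice, here is a single projection standing in for a full model (PyTorch; the names lora_A and lora_B are mine, not the article's):

```python
import torch
import torch.nn as nn

dense = nn.Linear(768, 768)              # stands in for one pre-trained projection
lora_A = nn.Linear(768, 4, bias=False)   # new low-rank "A" weights
lora_B = nn.Linear(4, 768, bias=False)   # new low-rank "B" weights
nn.init.zeros_(lora_B.weight)            # B starts at zero, so training begins from pre-trained behavior

for p in dense.parameters():
    p.requires_grad = False              # the original weights stay frozen

# Only the LoRA parameters are handed to the optimizer, so gradients and
# optimizer states (e.g., Adam moments) exist only for them.
optimizer = torch.optim.AdamW(list(lora_A.parameters()) + list(lora_B.parameters()), lr=1e-4)

x = torch.randn(2, 10, 768)
loss = (dense(x) + lora_B(lora_A(x))).sum()
loss.backward()                          # produces gradients only for lora_A / lora_B
optimizer.step()
```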
Implementing LoRA is an act of model surgery. In essence, you need to do a "layer-ectomy", swapping out the original dense layers that you want to add LoRA to with the new setup. If you've ever attempted model surgery, you understand the challenges in the tooling to "operate" on your model and verify that the operation succeeded. There are some other tutorials and examples in Keras, but I found them to be overly pre-scripted. This walkthrough is intended to be precise enough so that you can implement LoRA yourself just by looking at the visuals.
Visualization Details
To implement LoRA and do the surgery, we will work with a node graph visualization tool. Each block is an operation that takes the inputs on the left side and produces the data for the output variables on the right side. Links denote the passing of data from outputs to inputs, and circles on inputs mean the data is specified in place and is static.
Operations are either composite, marked with an "unbox" icon that decomposes them into a sub-graph whose inputs are the parent's inputs and whose outputs are the parent's outputs, or primitive, meaning they cannot be decomposed further and correspond to low-level tensor operations like those in NumPy or TensorFlow. Colors indicate data type and patterns indicate data shape. Blue means the data type is an integer, whereas purple/pink means it's a decimal type. Solid links indicate that the data is scalar, whereas dotted links indicate an array, with the number of dots between the dashes giving the number of dimensions. At the bottom of each graph is a table that characterizes the shape, type, and operation name of each variable carrying data in the model.
BERT LoRA
First, I'll show LoRA in the BERT implementation, and then I'll do the same for GPT.
Let's start with what a LoRA layer actually is. A LoRA layer starts with an input reflecting the hidden state or the original embeddings in the encoder, a hidden size (e.g., 768), and an integer r. We need to reshape the input so that it's 2D. If r = 4, and we have 2 inputs each padded to 10 tokens, then we reshape our [2 x 10 x 768] shape to [20 x 768].
There are 2 linear layers, called "A" and "B". We feed our [20 x 768] input into the "A" linear layer with hidden size 4 to produce a [20 x 4] shape.
Then we send the output into a "B" linear layer whose hidden size matches the original vector size. This takes the [20 x 4] output and multiplies it by a [4 x 768] weight matrix, bringing it back to [20 x 768]. Finally, we reshape it back to [2 x 10 x 768].
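The whole branch can be sketched in a few lines of NumPy with the exact shapes above (the matrices here are random stand-ins for the learned weights):

```python
import numpy as np

batch, tokens, hidden, r = 2, 10, 768, 4
x = np.random.randn(batch, tokens, hidden)     # hidden states: 2 inputs, 10 tokens each

A = np.random.randn(hidden, r) * 0.01          # "A" weights: 768 -> 4
B = np.zeros((r, hidden))                      # "B" weights: 4 -> 768 (zero-initialized, per the LoRA paper)

h = x.reshape(batch * tokens, hidden)          # [2 x 10 x 768] -> [20 x 768]
h = h @ A                                      # [20 x 768] @ [768 x 4] -> [20 x 4]
h = h @ B                                      # [20 x 4]   @ [4 x 768] -> [20 x 768]
lora_out = h.reshape(batch, tokens, hidden)    # back to [2 x 10 x 768]
```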
We feed the LoRA layer the same hidden state that goes into our dense layer, and then element-wise add its output to the original dense layer's output. Before the edit, the 3D Linear Layer output went to the place where the "add" block now links to. The LoRA Layer and "add" blocks were added during the surgery.
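In code, the same surgery can be expressed as a wrapper that swaps out the original dense layer. This is a minimal PyTorch sketch, and the class name LoRALinear is mine rather than a Graphbook block:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wraps a frozen pre-trained linear layer and adds a low-rank branch to its output."""
    def __init__(self, dense: nn.Linear, r: int = 4):
        super().__init__()
        self.dense = dense
        for p in self.dense.parameters():
            p.requires_grad = False               # original weights stay frozen
        self.lora_A = nn.Linear(dense.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, dense.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)        # branch contributes nothing until trained

    def forward(self, x):
        # The element-wise "add" block from the graph edit.
        return self.dense(x) + self.lora_B(self.lora_A(x))
```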
Then, we would do the same for the Value layer. In the implementation, the LoRA layer is only added to the Q and V projection matrices. These seem to be the most effective and efficient places to use LoRA; however, the authors also note that they leave the investigation of adapting other parameters (e.g., adding LoRA to the biases or to layer normalization) to future work.
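To apply this only to the query and value projections, you can walk the encoder layers and swap in the wrapper from the sketch above. The snippet below assumes a HuggingFace-style BertModel, where those projections live at encoder.layer[i].attention.self.query and .value; adjust the attribute paths for your own model:

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
for p in model.parameters():
    p.requires_grad = False                      # freeze the whole pre-trained stack

for layer in model.encoder.layer:
    attn = layer.attention.self
    attn.query = LoRALinear(attn.query, r=4)     # Q projection gets a LoRA branch
    attn.value = LoRALinear(attn.value, r=4)     # V projection gets a LoRA branch
    # the key projection and output dense keep their original (frozen) weights
```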
By looking at the breadcrumb bar, you can see we are in the Self Attention module of Layer 0 in the BERT encoder stack.
GPT LoRA
I also added another composite around the LoRA layer and the "add" operation so that I can drop it in as a single modifier.
In the GPT implementation (at least the one that Graphbook uses), as covered here, the Q, K, and V projections are all stored as a single matrix. These are split apart before being reshaped based on the number of attention heads, from
[batch_size x num_tokens x hidden_size] to
[batch_size x num_heads x num_tokens x hidden_size/num_heads].
We can drop in those "Add LoRA Layer" blocks and direct the data flow through these blocks before being reshaped.
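Here is an illustrative version of that flow in PyTorch, showing the split of the combined projection, the LoRA additions on the Q and V paths, and the reshape into heads; the shapes and names are mine rather than Graphbook's:

```python
import torch
import torch.nn as nn

batch, tokens, hidden, heads, r = 2, 10, 768, 12, 4
head_dim = hidden // heads

x = torch.randn(batch, tokens, hidden)
qkv_proj = nn.Linear(hidden, 3 * hidden)             # single matrix producing Q, K, and V
q, k, v = qkv_proj(x).split(hidden, dim=-1)          # three [2 x 10 x 768] tensors

# The "Add LoRA Layer" blocks sit on the Q and V paths before the head reshape.
lora_A_q, lora_B_q = nn.Linear(hidden, r, bias=False), nn.Linear(r, hidden, bias=False)
lora_A_v, lora_B_v = nn.Linear(hidden, r, bias=False), nn.Linear(r, hidden, bias=False)
nn.init.zeros_(lora_B_q.weight)
nn.init.zeros_(lora_B_v.weight)

q = q + lora_B_q(lora_A_q(x))
v = v + lora_B_v(lora_A_v(x))

def to_heads(t):
    # [batch x num_tokens x hidden] -> [batch x num_heads x num_tokens x hidden/num_heads]
    return t.reshape(batch, tokens, heads, head_dim).transpose(1, 2)

q, k, v = to_heads(q), to_heads(k), to_heads(v)
print(q.shape)  # torch.Size([2, 12, 10, 64])
```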
All implementation details of the LoRA layer are provided on GitHub.
Was this visualized implementation helpful? Did I get anything wrong? What do you want to see next? Let me know in the comments!
Published via Towards AI