KAN (Kolmogorov-Arnold Networks): A Starter Guide 🐣
Author(s): JAIGANESAN
Originally published on Towards AI.
I want to start with a question that has been on my mind lately. Is the AI community genuinely expecting a groundbreaking innovation? I'm asking this because of two recent scenarios that caught my attention.
First, when Mamba was released in late 2023, many people were excited. The question on everyone's mind was, "Is Mamba the next big thing? Would it revolutionize the traditional transformer architecture?"
Fast forward to the end of April 2024, when KAN was released. Once again, the same question echoed through the AI community: "Is KAN going to replace neural networks and MLPs?" (Please comment your thoughts on this.)
It seems to me that we're all eagerly anticipating the next big innovation in the AI world.
Let's get into the article!
It's simply not possible to cover the entire spectrum of KAN, from basics to advanced, in one article. However, my goal here is to give you a basic understanding of KAN. We will explore the parts, or building blocks, of KAN in this article.
Take a look at Image 1, which highlights the key differences between the Multi-Layer Perceptron (MLP) and Kolmogorov-Arnold Networks (KAN). One of the important differences between the two is that MLP is built on the principles of the Universal Approximation Theorem, while KAN is backed by the Kolmogorov-Arnold Representation Theorem.
Before we dive into the world of Kolmogorov-Arnold Networks (KAN), it's essential to have a solid understanding of neural networks. Let's start with the basics of MLPs and then explore the parts and working mechanisms of KANs.
Universal Approximation Theorem (UAT) 🐾
The Universal Approximation Theorem (UAT) states that a neural network with a single hidden layer (with enough neurons) can approximate any continuous function to an arbitrary degree of accuracy. This means that neural networks can approximate any real-world continuous function, no matter how complex, by adjusting the weights and biases within the network.
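To make this concrete, here is a minimal sketch (assuming NumPy) in which a single hidden layer of ReLU units approximates sin(x). The weights are constructed by hand as a piecewise-linear interpolation rather than learned, purely to show that one hidden layer with enough units can get arbitrarily close to a continuous function:

```python
import numpy as np

# Knots where the piecewise-linear approximation is allowed to bend.
knots = np.linspace(0, 2 * np.pi, 20)
targets = np.sin(knots)

# Slope of each linear piece, and the change in slope at each knot:
# these become the output weights of the hidden ReLU units.
slopes = np.diff(targets) / np.diff(knots)
coeffs = np.concatenate([[slopes[0]], np.diff(slopes)])

def one_hidden_layer(x):
    # Hidden layer: one ReLU unit per knot, ReLU(x - knot_i); output: weighted sum.
    hidden = np.maximum(0.0, x[:, None] - knots[:-1][None, :])
    return targets[0] + hidden @ coeffs

x = np.linspace(0, 2 * np.pi, 200)
print(np.max(np.abs(one_hidden_layer(x) - np.sin(x))))  # ~0.01: already a close fit
```

Adding more hidden units (knots) shrinks the error further, which is exactly what the UAT promises.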
Let's consider an example to illustrate this point. Imagine we've trained a Convolutional Neural Network (CNN) to identify whether an image contains a car or not. While the CNN can make predictions with a certain level of accuracy, it's not perfect and may miss some images that do contain cars. What's happening here is that the neural network is learning to approximate what a car looks like, rather than recognizing every possible instance of a car.
Multi-Layer Perceptron (ANN)
Convolution helps us extract the features of the car: its structure, shape, and so on. But it is the linear layer that learns from those features and classifies the images.
Consider Image 2, which shows an Artificial Neural Network / Multi-Layer Perceptron / linear layer with an input layer (4 input units), one hidden layer (5 neurons), and an output layer (one neuron, for binary classification) to predict whether the image has a car or not. The features x1, x2, x3, and x4 come from the convolution layer. These inputs are fed into the hidden layer, where they are combined with weights and biases and an activation is applied for non-linearity. The hidden layer output is then fed into the output layer, which gives logits; an activation is applied to produce the probability score, indicating the likelihood that the image contains a car.
To illustrate this, let's take a closer look at the four images below.
Note: The numbers in the images are for illustration purposes only. The same operations you are familiar with from neural networks, as shown in Image 2 A, are illustrated in the images below. I am trying to shift your perspective a little bit.
Suppose we have 5 images, and we want to predict whether each image has a car or not. In this example, I will illustrate how a prediction is made with a pre-trained model. The input matrix holds the features of the 5 images that come from the convolution layer. This input (5, 4) is multiplied by the transposed hidden-layer weight matrix (shape (5, 4) before transposing), resulting in a (5, 5) output matrix, as shown in Image 3. The hidden layer gives 5 features for each image because it has 5 neurons.
To this (5, 5) output, the bias vector (1, 5) is added to each row, giving output O1 (5, 5). To introduce non-linearity, the ReLU activation function is applied, which changes the negative numbers to 0, as shown in Image 4.
Then, the hidden layer output O1 is multiplied by the transposed output-layer weight matrix W1 (1, 5), resulting in an output of shape (5, 1), as shown in Image 5.
Next, we take the (5, 1) output and add the output-layer bias vector B2 (1, 1) to it. This is done through broadcasting, which gives the output-layer matrix O2 (5, 1). Then we apply the sigmoid activation function to O2, which converts the logits into probability scores. This gives us the final output, where any value less than 0.5 is classified as 0, and any value greater than or equal to 0.5 is classified as 1. Looking at the results, we can see that the first three images have cars in them.
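Here is a minimal NumPy sketch of the same forward pass, with random weights standing in for the pre-trained values shown in the images; the shapes match the (5, 4) input, the 5-neuron hidden layer, and the single-neuron output described above:

```python
import numpy as np

rng = np.random.default_rng(0)

X  = rng.normal(size=(5, 4))   # 5 images x 4 convolution features
W1 = rng.normal(size=(5, 4))   # hidden layer: 5 neurons, 4 inputs each
b1 = rng.normal(size=(1, 5))   # hidden-layer bias
W2 = rng.normal(size=(1, 5))   # output layer: 1 neuron, 5 inputs
b2 = rng.normal(size=(1, 1))   # output-layer bias

O  = X @ W1.T                        # (5,4) @ (4,5) -> (5,5)
O1 = np.maximum(0.0, O + b1)         # add bias (broadcast) + ReLU -> (5,5)
O2 = O1 @ W2.T + b2                  # (5,5) @ (5,1) + (1,1) -> (5,1) logits
probs = 1.0 / (1.0 + np.exp(-O2))    # sigmoid -> probability of "car"
preds = (probs >= 0.5).astype(int)   # threshold at 0.5
print(probs.ravel(), preds.ravel())
```

With the actual pre-trained weights from the images, this is exactly the computation that classifies the first three images as containing cars.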
I hope you now understand the UAT and neural networks. Let's dive into the world of KAN. To see KAN as a full picture, we need to understand its parts, starting with the foundational theorem.
Kolmogorov-Arnold Representation Theorem
The Kolmogorov-Arnold representation theorem backs Kolmogorov-Arnold Networks. It states that if f is a multivariate continuous function (a continuous function of multiple features) on a bounded domain, then f can be written as a finite composition of continuous functions of a single variable and the binary operation of addition. More specifically, for a smooth f : [0,1]^n → R,
f(x) = f(x_1, …, x_n) = Σ_{q=1}^{2n+1} Φ_q ( Σ_{p=1}^{n} φ_{q,p}(x_p) ),
where φ_{q,p} : [0,1] → R and Φ_q : R → R. In a sense, they showed that the only true multivariate operation is addition (+), since every other function can be written using univariate functions of the x_p and summation.
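As a toy illustration of this form (not the theorem's general construction), take a function similar to the example used in the KAN paper, f(x1, x2) = exp(sin(πx1) + x2²). It can be written with univariate inner functions, one univariate outer function, and addition as the only multivariate operation; in general the theorem allows up to 2n + 1 outer terms:

```python
import numpy as np

# Inner univariate functions phi act on one variable each;
# the outer univariate function Phi acts only on their sum.
phi_1 = lambda x1: np.sin(np.pi * x1)
phi_2 = lambda x2: x2 ** 2
Phi = np.exp

def f(x1, x2):
    # f(x1, x2) = exp(sin(pi * x1) + x2 ** 2):
    # the only multivariate operation used is addition.
    return Phi(phi_1(x1) + phi_2(x2))

print(f(0.3, 0.7))
```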
B-splines are an important concept in KAN, and to understand B-splines, we first need to understand Bezier curves.
Bezier Curve
Both Bezier curves and B-splines are parametric curves. If you have studied computer graphics, animation, or physics, you may have come across these terms. I believe you are familiar with linear regression: there, the line or hyperplane is defined by coefficients (parameters). The same idea applies here. A Bezier curve has a set of (learnable) control points that shape the curve. The curve is a linear combination of these control points weighted by Bernstein polynomials. We will see all of this below 😃.
If we have two points, we can connect them with a straight line; if we have three points, we need a quadratic equation; if we have four points, a cubic equation, and so on. In general, if we have N points and want a single curve that passes through all of them, we need a polynomial of degree N-1.
If we have lots of points and we want a single polynomial curve to pass through all of them (data fitting), the curve will have many peaks and valleys, making it oscillatory, as shown in Image 8.
One major issue with this method is that it can be very computationally expensive. For a polynomial of fixed degree we know exactly how many operations are required, but when we are working with a large number of points, hundreds or even thousands, the degree and therefore the equation become extremely complex, making it very costly in terms of computational resources.
To overcome this challenge, we need a way to represent these curves more simply and cheaply. This is where the Bezier curve comes in. The Bezier curve is a powerful tool for creating a smooth curve from a set of points. As I mentioned earlier, the Bezier curve is a parametric curve: all coordinates of the curve depend on an independent variable t, which ranges from 0 to 1.
When it comes to the Linear Bezier Curve, things are relatively straightforward. Given two distinct points, p0 and p1, a linear Bezier curve is essentially the straight line that connects these two points.
The mathematical representation of this linear Bezier curve is:
B(t) = (1 − t) p0 + t p1 = p0 + t (p1 − p0), for t in [0, 1].
The quantity p1 − p0 represents the displacement vector from the start point to the end point.
The quadratic Bezier curve's mathematical representation interpolates between two linear curves:
B(t) = (1 − t) [(1 − t) p0 + t p1] + t [(1 − t) p1 + t p2], for t in [0, 1].
The explicit form of the quadratic curve is:
B(t) = (1 − t)² p0 + 2 (1 − t) t p1 + t² p2.
The same pattern holds for the cubic Bezier curve, which interpolates between two quadratic curves. The explicit form of the cubic curve is:
B(t) = (1 − t)³ p0 + 3 (1 − t)² t p1 + 3 (1 − t) t² p2 + t³ p3.
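A quick sketch of the explicit cubic form in NumPy (the control points here are made up for illustration):

```python
import numpy as np

def cubic_bezier(p0, p1, p2, p3, t):
    # Explicit cubic Bezier: B(t) = (1-t)^3 p0 + 3(1-t)^2 t p1 + 3(1-t) t^2 p2 + t^3 p3
    t = np.asarray(t)[:, None]
    return ((1 - t) ** 3 * p0 + 3 * (1 - t) ** 2 * t * p1
            + 3 * (1 - t) * t ** 2 * p2 + t ** 3 * p3)

# Four 2-D control points (illustrative values).
p0, p1, p2, p3 = np.array([0., 0.]), np.array([1., 2.]), np.array([3., 3.]), np.array([4., 0.])
t = np.linspace(0.0, 1.0, 5)
print(cubic_bezier(p0, p1, p2, p3, t))  # starts at p0 (t = 0) and ends at p3 (t = 1)
```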
If you don't understand, bear with me. In the end, everything will make sense. I'm trying to explain the concept of KAN's learnable function, which is a bit different from what we have seen with traditional Multi-Layer Perceptrons (MLPs).
In MLPs, we use learnable parameters and fixed non-linear activation functions. KAN takes a different approach: it uses learnable non-linear functions. You can see this difference in Image 1. OK, now let's get back to the Bezier curve 😁
Take the time to check out the animated higher-order Bezier curves on Wikipedia.org. 📌 [Highly Recommended] It will help you understand Bezier curve interpolation better.
Image 13 shows the smooth curve defined by p0 to p3 (4 points). The curve passes through only the first and last points and interpolates between the intermediate ones. It uses a recursive calculation to interpolate the points, as shown in Images 12 A and 12 B. As the time step (t) moves, the point traced on the curve moves with it.
The recursive calculation is computationally expensive, but we can also compute the Bezier curve without recursion using the formula below:
B(t) = Σ_{i=0}^{n} (n choose i) (1 − t)^(n−i) t^i p_i, where p_0, …, p_n are the control points.
Here (n choose i) is the binomial coefficient: it gives the coefficients of a binomial raised to the n-th power.
If we have n points, we can find the degree n − 1 curve using this formula.
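The same idea, for an arbitrary number of control points, can be sketched with Python's math.comb for the binomial coefficients (the 2-D points below are made up for illustration):

```python
import math
import numpy as np

def bezier(points, t):
    """Evaluate a degree-(n-1) Bezier curve for n control points at parameters t."""
    points = np.asarray(points, dtype=float)
    n = len(points) - 1                       # degree of the curve
    t = np.asarray(t, dtype=float)[:, None]
    curve = np.zeros((len(t), points.shape[1]))
    for i, p in enumerate(points):
        # Bernstein basis polynomial: C(n, i) * (1-t)^(n-i) * t^i
        basis = math.comb(n, i) * (1 - t) ** (n - i) * t ** i
        curve += basis * p
    return curve

pts = [[0, 0], [1, 2], [3, 3], [4, 0], [5, 1]]   # 5 points -> degree-4 curve
print(bezier(pts, np.linspace(0, 1, 5)))
```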
In other words, we can interpret the Bezier curve in terms of the Bernstein basis polynomials, which act as its basis functions.
Image 15: The basis functions on the range t in [0,1] for cubic Bezier curves: blue: y = (1 − t)³, green: y = 3(1 − t)²t, red: y = 3(1 − t)t², and cyan: y = t³.
The formula in Image 14 C, applied to 4 points, produces the polynomials shown in Image 15. When the time step (x-axis) is 0, only point 0 contributes to the final curve B(t). As the time step moves, the contribution of each point changes. The blue line, y = (1 − t)³, is the polynomial associated with the first point (Image 13): when t is 0, y is 1, and as t approaches 1, the contribution of the first point drops to 0. These points control the curve; that is why they are called control points. As time moves away from the first point, we get closer to the second point, where the blue and green lines meet, and the same holds for the remaining points. Finally, at time step 1, the curve reaches the last point. This is how the 4 points contribute to the final curve.
However, there is a problem with Bezier curves. If we have n points, we need a Bezier curve of degree n − 1, which is computationally complex. And there is no local control in a degree-n Bezier curve: any change to a control point requires recalculating, and thus affects, the entire curve. If the curve changes in one place, it changes everywhere. To solve this and make things more efficient, researchers introduced B-splines. (Finally, the B-spline 😴😴)
B-spline (Basis spline)
Instead of one Bezier curve through 50 points, we make 10 Bezier curves with 5 points each and stitch them together. How about that 😜? This is, roughly, how B-splines are created.
In a B-spline curve, only a specific segment of the curve changes when the corresponding control points are moved. If we have n points and set the degree of the B-spline curve to k, we will have n − k Bezier curve segments.
For example, if we have 50 points and the degree of the B-spline curve is 4 (5 control points for each segment), we will have 46 Bezier curve segments (sketched in code after the list below). The points where the segments meet are called knots.
Points 1, 2, 3, 4, 5 → Bezier curve 1
Points 2, 3, 4, 5, 6 → Bezier curve 2
Points 3, 4, 5, 6, 7 → Bezier curve 3
…
Points 46, 47, 48, 49, 50 → Bezier curve 46
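A tiny sketch of that sliding-window view (a simplification: in a real B-spline the segments also share basis-function support through the knot vector):

```python
n_points, degree = 50, 4               # degree-4 segments use 5 control points each
n_segments = n_points - degree         # 46 segments, matching the count above

for s in range(1, n_segments + 1):
    window = list(range(s, s + degree + 1))   # 1-indexed control points for segment s
    if s <= 3 or s == n_segments:
        print(f"Segment {s}: points {window}")
```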
In this scenario, each Bezier segment is controlled by local control points, so changing one control point does not affect the whole curve. The B-spline curve is defined in terms of basis functions and control points: it is a piecewise polynomial function that is smooth and flexible.
Image 16: C(t) is the B-spline curve, N_i,p(t) are the B-spline basis functions of degree p, p_i are the control points, and t is the parameter in [0, 1].
B-spline basis functions: for p = 0 (degree zero), N_i,0(t) is
1 if t_i ≤ t < t_{i+1}, and 0 otherwise.
For p > 0 (degree p), the basis is defined by the Cox-de Boor recursion:
N_i,p(t) = ((t − t_i) / (t_{i+p} − t_i)) N_i,p−1(t) + ((t_{i+p+1} − t) / (t_{i+p+1} − t_{i+1})) N_{i+1,p−1}(t)
Let's take the example of 7 points with degree 3. Then 7 − 3 = 4, so we will have 4 Bezier curve segments. The red dot marks an example knot between two segments.
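Below is a minimal Python sketch of the Cox-de Boor recursion above and of the curve C(t) = Σ N_i,p(t) p_i, using a uniform knot vector chosen purely for illustration (real B-spline libraries typically clamp the knot vector so the curve starts and ends at the first and last control points):

```python
import numpy as np

def bspline_basis(i, p, t, knots):
    """Cox-de Boor recursion for the B-spline basis function N_{i,p}(t)."""
    if p == 0:
        return 1.0 if knots[i] <= t < knots[i + 1] else 0.0
    left = right = 0.0
    if knots[i + p] > knots[i]:
        left = (t - knots[i]) / (knots[i + p] - knots[i]) * bspline_basis(i, p - 1, t, knots)
    if knots[i + p + 1] > knots[i + 1]:
        right = (knots[i + p + 1] - t) / (knots[i + p + 1] - knots[i + 1]) * bspline_basis(i + 1, p - 1, t, knots)
    return left + right

def bspline_curve(control_points, degree, t_values):
    """C(t) = sum_i N_{i,p}(t) * p_i, here with a simple uniform knot vector."""
    P = np.asarray(control_points, dtype=float)
    n = len(P)
    knots = np.linspace(0.0, 1.0, n + degree + 1)
    return np.array([
        sum(bspline_basis(i, degree, t, knots) * P[i] for i in range(n))
        for t in t_values
    ])

pts = [[0, 0], [1, 2], [2, 3], [3, 3], [4, 1], [5, 0], [6, 2]]  # 7 control points, degree 3
# With a uniform knot vector the curve is only fully defined on [knots[degree], knots[n]],
# here [0.3, 0.7); the endpoint is excluded because the degree-0 basis uses half-open intervals.
print(bspline_curve(pts, 3, np.linspace(0.3, 0.7, 5, endpoint=False)))
```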
Learnable Functions
Before we delve into the full KAN (Kolmogorov-Arnold Network), let's take a step back and explore another crucial topic: how B-splines are used as a learnable function. Take a moment to glance at Image 1 and examine the shallow KAN formula, where you'll notice Φ(x). This is precisely what is represented in Equation 1 of Image 19.
Equation 1 in Image 19 has two parts: b(x) and spline(x). Only spline(x) is learnable. The activation function Φ(x) is the sum of the basis function b(x) and the spline function spline(x). If we take a closer look at Equation 3 in Image 19, we can see that c_i represents a control point (learnable), while B_i denotes a basis function of the B-spline.
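Here is a minimal PyTorch sketch of such a learnable activation, Φ(x) = b(x) + Σ c_i B_i(x) with b(x) = SiLU(x). To keep it short it uses degree-1 (hat-function) B-spline bases on a fixed grid and omits the extra scaling weights the paper uses; the official pykan implementation uses cubic (k = 3) splines:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnableSplineActivation(nn.Module):
    """Phi(x) = b(x) + spline(x), with b(x) = SiLU(x) and spline(x) = sum_i c_i * B_i(x).
    Only the control points c_i are learned in this sketch."""
    def __init__(self, grid_min=-2.0, grid_max=2.0, num_basis=8):
        super().__init__()
        self.register_buffer("centers", torch.linspace(grid_min, grid_max, num_basis))
        self.width = (grid_max - grid_min) / (num_basis - 1)
        self.coeffs = nn.Parameter(torch.zeros(num_basis))   # control points c_i (learnable)

    def forward(self, x):
        # Degree-1 B-spline ("hat") basis: 1 - |x - center| / width, clipped at 0.
        basis = torch.clamp(1.0 - (x.unsqueeze(-1) - self.centers).abs() / self.width, min=0.0)
        spline = basis @ self.coeffs          # spline(x) = sum_i c_i B_i(x)
        return F.silu(x) + spline             # Phi(x) = b(x) + spline(x)

phi = LearnableSplineActivation()
x = torch.linspace(-2, 2, 5)
print(phi(x))   # before training, the coeffs are zero, so Phi(x) == SiLU(x)
```

During training, gradient descent moves the control points c_i, which reshapes the curve of the activation itself rather than just scaling fixed activations as in an MLP.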
In total, a KAN has O(N²L(G + k)) ∼ O(N²LG) parameters, where L is the number of layers in the KAN, N is the width (the number of nodes, and hence functions, per layer), k is the spline order, and G is the number of knots (the grid size) of each B-spline.
Kolmogorov Arnold Networks (KAN)
We know the Kolmogorov Arnold Representation theorem backs the Kolmogorov Arnold Networks.
We have seen what Φ(x) is. In Image 20, n is the number of features in the vector x, p indexes the input features, and q indexes the second-layer inputs that come from the first-layer functions.
Image 21: Left: notation for the activations that flow through the KAN. Right: an activation function is parameterized as a B-spline, which allows switching between coarse-grained and fine-grained grids.
If vector x has 2 features (n = 2, indexed by p), then 2n + 1 = 5, so in the first layer we have 2 × 5 = 10 non-linear functions, as shown in Image 21 [φ_0,1,1, …]. These give 5 output features (2n + 1). The 5 features are then given to 5 (indexed by q) non-linear functions [φ_1,1,1, …], whose outputs sum to the single output x_2,1.
I want you to recall the Kolmogorov-Arnold representation theorem: if f is a multivariate continuous function on a bounded domain, then f can be written as a finite composition of continuous functions of a single variable and the binary operation of addition. According to this theorem, each variable passes through a non-linear function and the results are added. For example, x_1,1 is the sum of the outputs of the non-linear functions φ_0,1,1 and φ_0,2,1. The non-linear function φ_0,1,1 processes the first feature of vector x (x_0,1), and φ_0,2,1 processes the second feature (x_0,2). We have also seen above how these non-linear functions work.
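To tie this together, here is a toy NumPy sketch of the 2-input case: 10 inner univariate functions feed 5 summed hidden values, and 5 outer univariate functions are summed into the single output. Fixed functions (sin, tanh, and so on) stand in for what would be learnable splines in a real KAN:

```python
import numpy as np

# Toy stand-ins for the learned univariate functions phi; in a real KAN each
# of these would be a trainable B-spline as sketched earlier.
inner = [[np.sin, np.cos, np.tanh, np.square, np.abs],   # phi_{0,1,q}: act on x_{0,1}
         [np.cos, np.tanh, np.sin, np.abs, np.square]]   # phi_{0,2,q}: act on x_{0,2}
outer = [np.tanh, np.sin, np.cos, np.square, np.abs]      # phi_{1,q,1}: act on x_{1,q}

x = np.array([0.4, -0.7])   # 2 input features (n = 2)

# Layer 1: each hidden value x_{1,q} is a SUM of univariate functions of the inputs.
hidden = np.array([inner[0][q](x[0]) + inner[1][q](x[1]) for q in range(5)])  # 2n+1 = 5 values

# Layer 2: the output x_{2,1} is the sum of univariate functions of the hidden values.
y = sum(outer[q](hidden[q]) for q in range(5))
print(hidden, y)
```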
In an MLP, interpretation is very difficult because there can be millions or billions of parameters. In KAN, however, we can interpret the learned functions, as shown in Image 22.
The learnable function could be anything: the square of the input, a sine function, or an exponential function. The control points learn and adjust the curve to fit the function.
And that's a wrap! I hope I've explained the basic concepts of KAN and made them more understandable for you. I've put a lot of effort into breaking down the basics in a simple and clear way, so I'd love it if you could show your appreciation 👏. If you're interested in learning more about KAN and its building blocks, I highly recommend checking out the references I've provided.
Thanks for reading this article 🤩. If you found my article useful 👍, give Clapssss👏! Feel free to follow 😉 for more insights.
Letβs stay connected and explore the exciting world of AI together!
Join me on LinkedIn: linkedin.com/in/jaiganesan-n/ 🌍❤οΈ
References:
[1] Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, KAN: Kolmogorov-Arnold Networks (2024) Research Paper.
[2] Bezier Curve Explanation. Wikipedia.org [Highly Recommended]
[3] Binomial Coefficients Wikipedia.org
[4] B-spline Wikipedia.org
[5] B-Spline GeeksforGeeks.org [Recommended]