
Unraveling the Black Box: Explainability in Generative AI — Part 1

Author(s): saeed garmsiri

Originally published on Towards AI.

Photo by Kelli McClintock on Unsplash

Hey everyone, thanks for reading another one of my articles! Remember in the last article we did a deep dive into SHAP and LIME? Well, get ready because we’re about to embark on another exciting exploration of explainable AI, this time focusing on Generative AI.

Before we dive into the world of explainability in GenAI, it’s worth noting that the tone of this article, like its predecessor, is intentionally casual and approachable. Peppered with humor and relatable analogies, these articles are designed for individuals who are serious about understanding AI explainability but prefer learning in a more conversational, less formal environment. By employing this approach, we aim to make complex concepts more digestible for tech-savvy professionals, data scientists, and advanced students who appreciate a blend of depth and accessibility. This style allows us to explore intricate AI topics without sacrificing technical accuracy, making the content engaging for those who might be intimidated by overly academic texts. Now that you know the piece’s tone and structure, let’s begin diving into explainable GenAI.

Introduction

Gen AI has made remarkable strides in recent years, producing human-like text, generating realistic images, writing Medium articles :), and even developing coherent code. However, as these systems become more complex and influential, a critical question arises: How can we understand and interpret the decisions made by these AI models? This is where the concept of explainability comes into play.

Explainability, closely related to interpretability and transparency, refers to the ability to understand and describe how an AI system arrives at its outputs. In the context of Gen AI, it involves making the internal workings of complex models more accessible and comprehensible to humans. This article covers why explainability matters, what makes it hard, and several techniques used to achieve it in Gen AI systems.

The Importance of Explainability in Gen AI

  1. Trust and Adoption: As AI systems increasingly impact our daily lives, users need to trust their decisions. Explainable AI fosters this trust by providing insights into the decision-making process.
  2. Regulatory Compliance: Many industries, such as healthcare and finance, require transparent decision-making processes. Explainable AI helps meet these regulatory requirements.
  3. Debugging and Improvement: Understanding how a model works allows developers to identify and fix errors, biases, or unexpected behaviors more effectively.
  4. Ethical Considerations: Explainability helps in identifying and addressing potential biases or unfair treatment in AI systems.
  5. Knowledge Discovery: Insights gained from explainable models can lead to new scientific discoveries or business insights.

Challenges in Achieving Explainability

  1. Model Complexity: Modern Gen AI models, such as large language models (LLMs), are deep neural networks that often contain billions of parameters, making them inherently difficult to interpret.
  2. Non-linearity: Many AI models use non-linear functions, which can be challenging to explain in human-understandable terms.
  3. High-dimensional Data: Gen AI models often work with high-dimensional data, making it difficult to visualize or comprehend all relevant factors.
  4. Trade-off with Performance: In some cases, more explainable models may sacrifice some performance compared to their black-box counterparts.
  5. Varying Stakeholder Needs: Different stakeholders (e.g., developers, end-users, regulators) may require different levels and types of explanations.

For the rest of this article, we'll compare LIME and SHAP with other explainability techniques, using GPT-2 as a running example. We use GPT-2 because it is an open-source generative model. For each explainability technique, we develop a simple case study in Python. Feel free to skip the code snippets if you find them too detailed to follow.

Techniques for Explainability in Gen AI

  1. SHAP (SHapley Additive exPlanations), explained in more detail in my previous article

Now, let’s talk about SHAP. When I first encountered this technique, I was honestly a bit overwhelmed. All those Shapley values and game theory concepts? Yikes! But stick with me here, because once I got the hang of it, SHAP became my go-to explainability tool. Let me break it down for you the way I wish someone had done for me…

SHAP (SHapley Additive exPlanations) is a technique that helps us understand how different features in our data contribute to an AI model’s predictions. Here’s how it works:

  1. It considers every possible subset (coalition) of features.
  2. For each subset, it measures how much the prediction changes when a given feature is added to that subset (the feature's marginal contribution).
  3. It then takes a weighted average of these marginal contributions over all subsets.
  4. The result is a set of SHAP values, one per feature, that add up to the difference between the model's prediction and the average (baseline) prediction.
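For reference, these steps correspond to the standard Shapley value formula. Here F is the set of all features, S is a subset that excludes feature i, and v(S) is the model's output when only the features in S are present. How "absent" features are simulated (masking, sampling from background data, and so on) is an implementation choice that differs between SHAP variants.

```latex
\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(|F| - |S| - 1)!}{|F|!}\,\bigl[v(S \cup \{i\}) - v(S)\bigr],
\qquad
\sum_{i \in F} \phi_i = v(F) - v(\emptyset)
```

The second identity (efficiency) is why SHAP values for a prediction sum to the gap between that prediction and the baseline.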

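SHAP is powerful because Shapley values are fair (each feature's credit accounts for every feature combination), consistent (if a model changes so that a feature contributes more, that feature's SHAP value does not decrease), and model-agnostic (variants such as KernelSHAP work with any model). Exact computation grows exponentially with the number of features, so practical SHAP implementations rely on approximations or model-specific shortcuts. SHAP helps us understand which features drive the model's decisions, making complex AI systems more interpretable and trustworthy.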

```python
import torch
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

def get_model_response(prompt):
    """Greedy GPT-2 continuation of the prompt (the decoded text includes the prompt)."""
    inputs = tokenizer(prompt, return_tensors='pt')
    with torch.no_grad():
        output = model.generate(**inputs, max_length=50, num_return_sequences=1,
                                pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(output[0], skip_special_tokens=True)

def token_importance(text):
    """Leave-one-word-out importance: remove each word in turn and measure
    how much of the original response's vocabulary is lost."""
    words = text.split()
    base_words = set(get_model_response(text).split())
    importances = []

    for i in range(len(words)):
        modified_text = ' '.join(words[:i] + words[i + 1:])
        modified_words = set(get_model_response(modified_text).split())
        # 1 - (share of the original response's unique words that survive)
        overlap = len(modified_words & base_words) / max(len(base_words), 1)
        importances.append(1 - overlap)

    return importances

text = "Describe a high-end smartphone"
importances = token_importance(text)

plt.figure(figsize=(12, 6))
plt.bar(text.split(), importances)
plt.title('Word Importance for Input Text')
plt.xlabel('Words')
plt.ylabel('Importance')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

print("Generated text:", get_model_response(text))
```
Figure 1: Word Importance Created by SHAP
Generated text: 
Describe a high-end smartphone, such as the Galaxy S6 or S6 Edge, that has a Snapdragon 835 processor and a 4.5-inch display.

Code Interpretation:

  • Methodology: SHAP explains a model's output by computing Shapley values from cooperative game theory, capturing the contribution of each feature to the prediction. The code above is a simplified, SHAP-inspired approximation rather than a call to the shap library: it removes one word at a time (leave-one-out occlusion) instead of averaging over all subsets of words.
  • Model Prediction: The code first generates GPT-2's response to the full prompt.
  • Feature Importance: For each word, it regenerates the response without that word and scores the word as 1 minus the fraction of the original response's unique words that still appear. A higher score means that removing the word changes the output more.
  • Visualization: The scores are shown as a bar plot over the words of the input prompt. In this run, “Describe” has the highest impact.
  • The generated text shows what GPT-2 produces, given the full prompt. This gives you an idea of how the model interprets and expands on the input.

Keep in mind that this is a simplified approach to understanding word importance. It doesn't capture all the nuances of how GPT-2 processes text, and leave-one-out scores can miss words whose effect is shared with other words, which true Shapley values account for. Still, it gives a basic picture of word importance in the given prompt.
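To see why leave-one-out scores are only an approximation of SHAP, here is a small, self-contained sketch that computes exact Shapley values by brute force and compares them with leave-one-out scores. The value function `is_premium` is hypothetical (a stand-in for a model, not GPT-2): it returns 1.0 whenever either "high-end" or "flagship" appears in the prompt. Brute-force enumeration is only practical for a handful of words, because the number of subsets grows exponentially.

```python
from itertools import combinations
from math import factorial

def shapley(players, value):
    """Exact Shapley values: weighted average of marginal contributions over all subsets."""
    n = len(players)
    phi = {}
    for p in players:
        others = [q for q in players if q != p]
        phi[p] = sum(
            # Shapley weight |S|! (n - |S| - 1)! / n! times the marginal contribution of p
            factorial(k) * factorial(n - k - 1) / factorial(n)
            * (value(frozenset(s) | {p}) - value(frozenset(s)))
            for k in range(n)
            for s in combinations(others, k)
        )
    return phi

def leave_one_out(players, value):
    """Importance = drop in value when a single player is removed from the full set."""
    full = frozenset(players)
    return {p: value(full) - value(full - {p}) for p in players}

# Hypothetical value function: the output is "premium" (1.0) if either word is present
words = ["Describe", "a", "high-end", "flagship", "smartphone"]

def is_premium(present):
    return 1.0 if {"high-end", "flagship"} & present else 0.0

print(leave_one_out(words, is_premium))
print({w: round(v, 3) for w, v in shapley(words, is_premium).items()})
```

Leave-one-out gives both "high-end" and "flagship" zero credit, because removing either one alone changes nothing, while the Shapley values split the credit evenly between them. This redundancy effect is one reason SHAP looks at many feature combinations rather than single deletions.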

2. LIME (Local Interpretable Model-agnostic Explanations), also explained in more detail in my previous article

LIME is a technique that explains individual predictions of any machine learning classifier in an interpretable and faithful manner. It works by approximating the behavior of the complex model locally with an interpretable model.

In simpler terms, LIME is like a friendly translator between complex AI-speak and human language. It takes a single prediction from our AI model and says, “Hey, let me explain why the model made this decision in a way that makes sense to you.”

How does it work? Imagine you're trying to figure out why your friend likes a particular song. You might play around with different versions of the song (changing the tempo, removing instruments, altering lyrics) and see how your friend reacts to each change. LIME does something similar with our AI model.

Here's the cool part: LIME creates a simple, interpretable model that mimics how our complex AI behaves for a specific input. It's like creating a simple sketch that captures the essence of a detailed painting.

Let's see LIME in action with our GPT-2 model:

```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from lime.lime_text import LimeTextExplainer

model = GPT2LMHeadModel.from_pretrained('gpt2')
# Left padding keeps the last real token at position -1 for every prompt in a batch
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', padding_side='left')
# GPT-2 has no pad token, so reuse the end-of-sequence token
tokenizer.pad_token = tokenizer.eos_token

def get_model_response(prompt):
    inputs = tokenizer(prompt, return_tensors='pt')
    with torch.no_grad():
        output = model.generate(**inputs, max_length=50, num_return_sequences=1,
                                pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(output[0], skip_special_tokens=True)

text = "Describe a high-end smartphone"
generated_text = get_model_response(text)

# The token GPT-2 would predict next for the full prompt is the "class" LIME explains
with torch.no_grad():
    target_id = model(**tokenizer(text, return_tensors='pt')).logits[0, -1].argmax().item()

def gpt2_predict(texts):
    """Return an (n_samples, 2) array: [P(other token), P(original next token)]."""
    inputs = tokenizer(list(texts), return_tensors='pt', padding=True)
    # Position ids that skip over the left padding
    position_ids = (inputs['attention_mask'].cumsum(-1) - 1).clamp(min=0)
    with torch.no_grad():
        logits = model(**inputs, position_ids=position_ids).logits
    probs = torch.softmax(logits[:, -1, :], dim=-1)[:, target_id].cpu().numpy()
    return np.column_stack([1 - probs, probs])

explainer = LimeTextExplainer(class_names=['other token', 'original next token'])
exp = explainer.explain_instance(text, gpt2_predict, labels=(1,), num_features=10, num_samples=100)

weights = exp.as_list(label=1)
fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(range(len(weights)), [w for _, w in weights], align='center')
ax.set_yticks(range(len(weights)))
ax.set_yticklabels([word for word, _ in weights])
ax.invert_yaxis()  # Most important features at the top
ax.set_xlabel('Feature Importance')
ax.set_title('LIME Explanation for GPT-2 Output')
plt.tight_layout()
plt.show()

print("Generated text:", generated_text)
```
Figure 2: Word Importance Created by LIME
Generated text: 
Describe a high-end smartphone, such as the Galaxy S6 or S6 Edge, that has a Snapdragon 835 processor and a 4.5-inch display.

As shown in Figure 2, LIME's feature importance graph differs noticeably from the SHAP-style plot in Figure 1. This is because the two techniques explain model outputs in different ways, and in our examples they also score different quantities. Let's look at the model, its output, and LIME's approach to explainability:

  • Methodology: LIME explains model predictions by training an interpretable model locally on perturbed instances of the input data.
  • Model Prediction: It starts by obtaining predictions from the black-box model (in this case, GPT-2) for the input text.
  • Explanation Generation:
    a. Perturbation: LIME perturbs the input text (for example, by removing words) to create a dataset of modified instances.
    b. Prediction: It queries the original model on these modified instances.
    c. Feature Importance: LIME then trains an interpretable model (e.g., a linear model) on these predictions, weighting each modified instance by its similarity to the original input; the model's coefficients serve as feature importances.
  • Visualization: LIME typically visualizes feature importance scores, indicating which words or tokens in the input text contributed most to the model’s output.
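The three explanation-generation steps can be sketched from scratch with NumPy. This is a simplified illustration, not the `lime` library's implementation: it uses random keep/drop word masks, an exponential kernel on cosine distance (in the spirit of LIME's text defaults), and plain weighted least squares instead of the ridge regression LIME uses by default. The black box `toy_model` is hypothetical; in practice you would plug in a real model's score, such as GPT-2's next-token probability.

```python
import numpy as np

def lime_text_sketch(words, predict, num_samples=500, kernel_width=0.25, seed=0):
    """Local linear explanation of predict() around the text made of `words`."""
    rng = np.random.default_rng(seed)
    d = len(words)
    # a. Perturbation: random keep/drop masks; the first sample is the original text
    masks = rng.integers(0, 2, size=(num_samples, d))
    masks[0] = 1
    texts = [" ".join(w for w, keep in zip(words, m) if keep) for m in masks]
    # b. Prediction: query the black-box model on every perturbed text
    y = np.array([predict(t) for t in texts], dtype=float)
    # Cosine similarity between a binary mask and the all-ones mask is sqrt(kept / d)
    cos_sim = np.sqrt(masks.sum(axis=1) / d)
    sample_weight = np.exp(-((1 - cos_sim) ** 2) / kernel_width ** 2)
    # c. Feature importance: weighted least squares on [intercept, masks]
    X = np.column_stack([np.ones(num_samples), masks])
    sw = np.sqrt(sample_weight)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return dict(zip(words, coef[1:]))

def toy_model(text):
    """Hypothetical black box: a 'premium-ness' score driven by two words."""
    present = set(text.split())
    logit = 2.0 * ("high-end" in present) + 1.0 * ("smartphone" in present) - 1.0
    return 1.0 / (1.0 + np.exp(-logit))

explanation = lime_text_sketch(["Describe", "a", "high-end", "smartphone"], toy_model)
for word, weight in sorted(explanation.items(), key=lambda kv: -abs(kv[1])):
    print(f"{word:>12}: {weight:+.3f}")
```

Because the toy model only reacts to "high-end" and "smartphone", the local linear model should give those two words clearly positive weights and leave "Describe" and "a" near zero.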

3. Attention Mechanisms

Attention mechanisms allow models to focus on specific parts of the input when producing each part of the output. They were first popularized in sequence-to-sequence tasks like machine translation and text summarization, and today they are the core building block of transformer models such as GPT-2.

Imagine you’re reading a complex sentence. As you read each word, your brain doesn’t give equal importance to all the other words in the sentence. Instead, it focuses more on certain words that are most relevant to understanding the current word or predicting the next one. This is essentially what attention mechanisms do in neural networks.

In the context of language models like GPT-2:

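  1. For each word (or token) the model is processing, it calculates how much “attention” to pay to every word that comes before it, including itself. GPT-2 uses causal attention, so a token cannot look ahead at words that come later.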
  2. This attention is represented as a set of weights — higher weights mean more attention.
  3. The model uses these weights to create a weighted sum of all the word representations, focusing more on the important words and less on the irrelevant ones.
  4. This process happens in multiple “heads” (different sets of attention weights) and multiple layers, allowing the model to capture different types of relationships between words.
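The four steps above can be illustrated with a minimal NumPy sketch of a single causal self-attention head. The random matrices stand in for learned projections; GPT-2 additionally runs many heads in parallel and adds output projections, residual connections, and layer normalization, so treat this as an illustration of the mechanism rather than GPT-2's actual code.

```python
import numpy as np

def causal_self_attention(X, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention with a causal mask.

    X: (seq_len, d_model) token representations; Wq, Wk, Wv: (d_model, d_head).
    Returns the outputs (seq_len, d_head) and the attention weights (seq_len, seq_len).
    """
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    # Similarity of every query with every key, scaled by sqrt(d_head)
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    # Causal mask: position i may only attend to positions 0..i
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    # Row-wise softmax (subtract the max for numerical stability)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    # Weighted sum of value vectors: more attention means more influence
    return weights @ V, weights

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))  # 5 tokens with made-up 8-dimensional embeddings
Wq, Wk, Wv = (rng.normal(size=(8, 4)) for _ in range(3))
out, w = causal_self_attention(X, Wq, Wk, Wv)
print(np.round(w, 2))  # lower-triangular; each row sums to 1
```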

Here's how to extract and visualize the attention weights of GPT-2's last layer:

```python
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name = 'gpt2'
# Eager attention makes sure attention weights are returned
model = GPT2LMHeadModel.from_pretrained(model_name, attn_implementation='eager')
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

def get_model_response_and_attention(prompt):
    inputs = tokenizer(prompt, return_tensors='pt')
    with torch.no_grad():
        # One forward pass to collect the attention weights of every layer
        outputs = model(**inputs, output_attentions=True)
        # Continuation of the prompt (the LM head maps hidden states to vocabulary tokens)
        generated = model.generate(**inputs, max_length=50, pad_token_id=tokenizer.eos_token_id)
    generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)

    # outputs.attentions: one tensor per layer, each (batch, heads, seq_len, seq_len)
    last_layer = outputs.attentions[-1][0]  # (heads, seq_len, seq_len), 12th layer in GPT-2
    return generated_text, last_layer, inputs['input_ids']

text_prompt = "Describe a high-end smartphone."
generated_text, attentions, input_ids = get_model_response_and_attention(text_prompt)

# Convert input_ids to tokens
input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

# Average the heads of the last layer into a single (seq_len, seq_len) map
attention = attentions.mean(dim=0).numpy()

plt.figure(figsize=(12, 10))
sns.heatmap(attention, cmap="viridis", xticklabels=input_tokens, yticklabels=input_tokens)
plt.title('Attention Weights, Last Layer (mean over heads)')
plt.xlabel('Attended-to Tokens')
plt.ylabel('Attending Tokens')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=45)
plt.tight_layout()
plt.show()

print("Generated text:", generated_text)
```
Figure 3: Token relations using Attention method
Generated text: 
Describe a high-end smartphone, such as the Galaxy S6 or S6 Edge, that has a Snapdragon 835 processor and a 4.5-inch display.

Now let’s check our observation from the code:

  • Methodology: This approach visualizes the attention weights of a transformer such as GPT-2 to see how the model relates tokens in the input sequence.
  • Model Attention: GPT-2, like other transformer models, computes attention weights between input tokens in every head of every layer (the base GPT-2 model has 12 layers with 12 heads each).
  • Attention Extraction: The code takes the last layer and averages its heads. Each row shows how much one token attends to each earlier token, and each row sums to 1.
  • Visualization (Heatmap): Attention weights are visualized as a heatmap, where each cell represents the strength of attention between two tokens. Because GPT-2 uses causal attention, every cell above the diagonal is zero.
  • Insight: This visualization shows which tokens the model focuses on during processing. Treat it as a hint rather than a complete explanation: attention weights describe where information is routed, not how much each token ultimately contributes to the output.

3.1 LIME & SHAP vs Attention Mechanisms:

3.1.1 Nature of Explanation:

  • LIME and SHAP explain individual predictions using local surrogate models or Shapley values, respectively. Both are local by design, although SHAP values can be aggregated across many inputs to give a global view of feature importance.
  • Attention Visualization shows how the model processes the input sequence through attention mechanisms, focusing more on understanding the model’s internal workings.

3.1.2 Interpretability:

  • LIME and SHAP provide direct feature importances or contributions, aiding in interpretability for non-technical stakeholders.
  • Attention Visualization is more technical, providing insights into model mechanics rather than direct feature importance.

3.1.3 Applicability:

  • LIME and SHAP are model-agnostic and applicable to a wide range of models.
  • Attention Visualization is specific to attention-based models, such as transformers like GPT-2, and relies on their attention mechanisms.

Think of LIME and SHAP as detectives trying to solve a mystery. They're looking at clues (features) to figure out why the AI made a certain decision. They'll tell you things like, “The word ‘high-end’ was really important in deciding this was about an expensive smartphone.”

Now, attention visualization is more like watching the AI's thought process in action. It's as if we could see the gears turning in the AI's brain, showing us which words it's focusing on as it reads the sentence.

Both approaches are useful, but in different ways. It's like having a map of a city (LIME and SHAP) versus actually walking through the streets (attention visualization). The map gives you an overview, while the walk lets you experience the details. Choosing between these methods is like picking the right tool for a job. Sometimes you need a bird's-eye view (LIME and SHAP), and other times you want to get your hands dirty and see the inner workings (attention visualization). It all depends on what you're trying to understand about your AI model and how deep you want to dive.

4. Counterfactual Explanations

Counterfactual explanations provide insight into how the model’s output would change if the input were slightly different. They answer the question: “What would need to change for the model to produce a different output?”

Let's explore this concept further with an example :)

Let’s say you’re a chef who just cooked a delicious pasta dish. A food critic tastes it and gives it a high rating. Now, you’re curious: what would have happened if you had changed one ingredient? Would the rating still be high if you had used a different type of cheese, or if you had added more salt? This is essentially what counterfactual explanations do for machine learning models. In the context of language models:

  1. We start with an original input that produces a certain output.
  2. We then make small changes to this input and observe how the output changes.
  3. These changes help us understand which parts of the input are most crucial for the model’s decision.
  4. By doing this, we can identify what features or words are most influential in the model’s output.
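These steps can be sketched as a brute-force search for one-word counterfactuals: swap each word for a candidate replacement and keep the swaps that change the model's decision. The classifier `toy_tier` and the candidate vocabulary are hypothetical stand-ins (a real setup would compare GPT-2 outputs, as in the example that follows); the point is the search pattern, not the model.

```python
def toy_tier(prompt):
    """Hypothetical stand-in for a model: labels a prompt 'premium' or 'budget'."""
    words = prompt.lower().split()
    return "premium" if ("high-end" in words or "flagship" in words) else "budget"

def single_word_counterfactuals(prompt, predict, vocabulary):
    """Return the original label and every one-word substitution that changes it."""
    words = prompt.split()
    original = predict(prompt)
    found = []
    for i, old in enumerate(words):
        for new in vocabulary:
            if new == old:
                continue
            candidate = " ".join(words[:i] + [new] + words[i + 1:])
            if predict(candidate) != original:
                found.append((old, new, candidate))
    return original, found

original, found = single_word_counterfactuals(
    "Describe a high-end smartphone", toy_tier, ["low-end", "average", "smartwatch"])
print("Original label:", original)
for old, new, candidate in found:
    print(f"{old!r} -> {new!r}: {candidate}")
```

Every label-flipping edit replaces "high-end", while swapping "smartphone" for "smartwatch" leaves the label unchanged, which is exactly the kind of actionable "what would need to change" answer counterfactuals are meant to give.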

Counterfactual explanations are particularly useful because they provide actionable insights. They don’t just tell us what the model did, but what would need to change for the model to do something different. Now, let’s look at a simple code example using the “Describe a high-end smartphone” prompt:

```python
import torch
import matplotlib.pyplot as plt
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from transformers import GPT2LMHeadModel, GPT2Tokenizer

nltk.download('punkt')
nltk.download('punkt_tab')  # needed by word_tokenize in newer NLTK releases
nltk.download('stopwords')

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

def generate_text(prompt):
    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=50)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=100, num_return_sequences=1,
                                 pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def calculate_similarity(text1, text2):
    """Jaccard similarity between the sets of content words (no stopwords) in two texts."""
    stop_words = set(stopwords.words('english'))

    def content_words(text):
        return {w.lower() for w in word_tokenize(text)
                if w.isalnum() and w.lower() not in stop_words}

    words1, words2 = content_words(text1), content_words(text2)
    union = words1 | words2
    return len(words1 & words2) / len(union) if union else 0.0

original_prompt = "Describe a high-end smartphone"
original_output = generate_text(original_prompt)

print(f"Original prompt: {original_prompt}")
print(f"Original output: {original_output}\n")

# Counterfactual examples: each changes one aspect of the original prompt
counterfactuals = [
    "Describe a low-end smartphone",
    "Describe a high-end laptop",
    "Describe an average smartphone",
    "Describe a high-end smartwatch",
]

similarities = []

for cf_prompt in counterfactuals:
    cf_output = generate_text(cf_prompt)
    similarity = calculate_similarity(original_output, cf_output)
    similarities.append(similarity)
    print(f"Counterfactual prompt: {cf_prompt}")
    print(f"Counterfactual output: {cf_output}")
    print(f"Similarity to original: {similarity:.2f}\n")

plt.figure(figsize=(12, 6))
plt.bar(counterfactuals, similarities)
plt.title('Similarity of Counterfactual Outputs to Original Output')
plt.xlabel('Counterfactual Prompts')
plt.ylabel('Similarity Score')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()
```
Figure 4: Counterfactual Explanation

This visualization helps us quickly see which changes to the prompt result in outputs that are more or less similar to the original. For example:

  • If changing “high-end” to “low-end” results in a low similarity score, it suggests that the “high-end” aspect significantly influences the model’s output.
  • If changing “smartphone” to “smartwatch” still results in a relatively high similarity score, it might indicate that the model uses similar descriptors for these two types of devices.

This plot provides a clear, visual way to understand the impact of different counterfactual changes, making it easier to interpret how the model’s output changes in response to input modifications.

4.1 Comparison with LIME and SHAP:

4.1.1 Similarities:

  • Model-Agnostic: Like LIME and SHAP, counterfactual explanations can be applied to various models, including GPT-2.
  • Local Interpretability: Counterfactuals focus on explaining individual predictions, providing insights into why a particular prediction was made.

4.1.2 Differences:

a. Methodology:

  • LIME and SHAP rely on perturbations or Shapley values to explain feature importance.
  • Counterfactuals directly modify the input to identify minimal changes for different predictions.

b. Visualization:

  • LIME and SHAP visualize feature importances or Shapley values, indicating the impact of each feature on the prediction.
  • Counterfactuals visualize changes needed to alter predictions, highlighting the sensitivity of the model to different parts of the input.

c. Applicability:

  • LIME and SHAP are effective for explaining feature-level importance across different models.
  • Counterfactuals are particularly useful for understanding how small changes in the input text can lead to different model predictions, offering insights into model sensitivity.

You remember our detectives, LIME and SHAP, and how they work, right?

Now, counterfactual explanations are more like asking, “What if?” It's like wondering how the story would change if a key character made a different choice. With AI, we're asking, “What if we changed this word? How would that affect the AI's decision?” Both approaches help us peek into the AI's “thought process,” but in different ways. LIME and SHAP help us understand what the AI is focusing on right now, while counterfactuals let us explore alternative scenarios.

When it comes to choosing a tool, sometimes you need to know what's important right now (LIME and SHAP), and other times you want to explore “what ifs” (counterfactuals). It all depends on what you're trying to understand about your AI and how you want to explain its decisions. Remember, the goal here is to make our AI less of a mystery and more of a transparent partner. Whether we're identifying key features or exploring alternative scenarios, we're working towards AI systems that we can understand and trust.

5. Layer-wise Relevance Propagation (LRP)

LRP is a technique that explains deep neural network decisions by decomposing the output in terms of input features. It works by backpropagating the relevance of neurons from the output layer to the input layer.

OK! I know it's a long article, but let's use our imagination again to make it easier to read ;)

Imagine you’re trying to understand why a complex machine made a certain decision. Instead of just looking at the final output, you want to trace back through all the gears and mechanisms to see how each part contributed to the final result. This is essentially what LRP does for neural networks. In the context of language models:

  1. We start with the model’s output and work backwards through the layers of the network.
  2. At each layer, we redistribute the relevance (importance) of each neuron to the neurons in the previous layer, in proportion to how much each of them contributed to its activation.
  3. This process continues all the way back to the input layer, showing us how much each input feature (in our case, each word) contributed to the final output.
  4. The result is a “heat map” of relevance, showing which parts of the input were most important for the model’s decision.
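To make this backward relevance flow concrete, here is a minimal sketch of the LRP epsilon rule on a tiny, made-up ReLU network (NumPy only, not GPT-2). Assumptions: the network has no biases and a single output score, so the relevance that reaches the inputs should add up, approximately, to the output. Real LRP toolkits add further rules for biases and for components such as attention and normalization layers.

```python
import numpy as np

def lrp_epsilon(weights, x, eps=1e-6):
    """LRP epsilon rule for a bias-free ReLU MLP with one output score.

    weights: list of matrices [W1, ..., WL], Wl of shape (in_dim, out_dim);
    the last matrix must have out_dim == 1. Returns one relevance per input feature.
    """
    # Forward pass, keeping the input activation of every layer
    activations = [x]
    for i, W in enumerate(weights):
        z = activations[-1] @ W
        # ReLU on hidden layers, identity on the output layer
        activations.append(np.maximum(z, 0) if i < len(weights) - 1 else z)

    # The output score is the total relevance to distribute
    relevance = activations[-1]
    # Walk backwards, splitting each neuron's relevance among its inputs
    for W, a in zip(reversed(weights), reversed(activations[:-1])):
        z = a @ W                                   # pre-activations of the upper layer
        z = z + eps * np.where(z >= 0, 1.0, -1.0)   # stabilizer avoids division by zero
        s = relevance / z                           # relevance per unit of pre-activation
        relevance = a * (s @ W.T)                   # share assigned to each lower neuron
    return relevance

rng = np.random.default_rng(0)
weights = [rng.normal(size=(4, 8)), rng.normal(size=(8, 1))]  # random toy network
x = np.array([1.0, 0.5, -0.3, 2.0])
relevance = lrp_epsilon(weights, x)
output = (np.maximum(x @ weights[0], 0) @ weights[1]).item()
print("relevance per input:", np.round(relevance, 3))
print("sum of relevance:", round(float(relevance.sum()), 3), "| model output:", round(output, 3))
```

The conservation check in the last line is the defining property of LRP: relevance is redistributed, not created or lost, as it flows from the output back to the inputs.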

LRP is particularly useful because it provides a fine-grained explanation of the model's decision-making process, taking the entire network structure into account. Now, let's look at a code example using the same prompt. Note that implementing full LRP for a complex model like GPT-2 is quite involved, because components such as softmax attention and layer normalization need dedicated propagation rules. The code below is therefore not true LRP: as a rough stand-in for relevance, it uses how strongly the last token attends to each input token, averaged over all layers and heads:

```python
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Eager attention makes sure attention weights are returned
model = GPT2LMHeadModel.from_pretrained('gpt2', attn_implementation='eager')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

def generate_text_and_relevance(prompt):
    inputs = tokenizer(prompt, return_tensors='pt')

    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # Greedy prediction for the next token
    predicted_token_id = outputs.logits[0, -1].argmax()

    # Stack layers -> (layers, batch, heads, seq, seq); average over layers and heads
    attention = torch.stack(outputs.attentions).mean(dim=(0, 2))  # (batch, seq, seq)
    # Row of the last position: how much it attends to each input token
    relevance = attention[0, -1]

    # Min-max normalize to [0, 1]
    relevance = (relevance - relevance.min()) / (relevance.max() - relevance.min() + 1e-12)

    return tokenizer.decode(predicted_token_id), relevance.numpy(), inputs['input_ids'][0]

prompt = "Describe a high-end smartphone"

# Predict the next token and compute the attention-based relevance
predicted_token, relevance, input_ids = generate_text_and_relevance(prompt)

# Tokens of the input, aligned with the relevance scores
tokens = tokenizer.convert_ids_to_tokens(input_ids)

plt.figure(figsize=(12, 6))
sns.barplot(x=tokens, y=relevance)
plt.title(f'Token Relevance for Predicting: "{predicted_token}"')
plt.xlabel('Input Tokens')
plt.ylabel('Relevance Score')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

print(f"Predicted next token: {predicted_token}")
```
Figure 5: LRP Graph

This attention-based proxy gives a rough idea of which input tokens the last position draws on when GPT-2 predicts the next token, and it can hint at potential biases or unexpected behaviors. Keep in mind, though, that it is not LRP: averaged attention weights are not guaranteed to reflect each token's actual contribution to the output. A full LRP implementation would propagate relevance from the output back through every layer of the network, conserving the total relevance at each step, and would provide a more faithful explanation of the model's behavior.

5.1 Comparing LRP with LIME and SHAP:

5.1.1 Similarities:

  • Broad Applicability: Like LIME and SHAP, LRP can be applied to complex models such as GPT-2. Unlike them, however, it is not model-agnostic: it needs access to model internals such as layer activations and weights.
  • Local Interpretability: LRP, LIME, and SHAP all focus on providing explanations at a local level, offering insights into why specific predictions were made for individual instances.

5.1.2 Differences:

a. Methodology:

  • LRP: Propagates relevance backwards through model layers, attributing importance to each input token based on how it influences model output through activations and weights.
  • LIME and SHAP: Rely on perturbations or Shapley values to explain feature importance by evaluating how predictions change with variations in input features.

b. Visualization:

  • LRP: Visualizes token-level relevance scores, often as a heat map over the input, indicating the contribution of each token to the model's decision.
  • LIME and SHAP: Visualize feature importances or Shapley values, providing insights into which features (or tokens, in text) contribute most significantly to the model’s prediction.

c. Applicability:

  • LRP: Particularly effective for understanding token-level contributions and how different parts of the input text influence model decisions.
  • LIME and SHAP: Are versatile in explaining feature-level importance across different models and data types. LIME explanations are local, while SHAP values can also be aggregated into global feature attributions.

Depending on what you need to understand about your model, you might choose one method over the other. LIME and SHAP treat the model as a black box and summarize how much each input feature matters, while LRP traces relevance through the network's internals for a more detailed, token-level perspective. Together, they provide a comprehensive toolkit for making sense of complex AI models and their behaviors.

Conclusion

In this article, the first part of a series on GenAI explainability, we investigated the following techniques:

  • SHAP and LIME are like having a translator, helping us understand AI decisions in human terms.
  • Attention Mechanisms show us where the AI is looking, much like following someone’s gaze to understand what they find important.
  • Counterfactual Explanations are the “what-ifs” of AI, helping us understand how different choices lead to different outcomes.
  • Layer-wise Relevance Propagation is like tracing a river to its source, showing us how information flows through the AI.

In the next piece, I'll cover more explainability techniques, so stay tuned: there's more to come.
