Meet Gemma Scope and ShieldGemma: Google DeepMind’s New Releases for Interpretability and Guardrailing
Last Updated on August 7, 2024 by Editorial Team
Author(s): Jesus Rodriguez
Originally published on Towards AI.
I recently started an AI-focused educational newsletter that already has over 170,000 subscribers. TheSequence is a no-BS (meaning no hype, no news, etc.) ML-oriented newsletter that takes 5 minutes to read. The goal is to keep you up to date with machine learning projects, research papers, and concepts. Please give it a try by subscribing below:
TheSequence | Jesus Rodriguez | Substack
The best source to stay up-to-date with the developments in the machine learning, artificial intelligence, and data…
thesequence.substack.com
Google’s Gemma is one of the most interesting efforts in modern generative AI, pushing the boundaries of small language models (SLMs). Unveiled last year by Google DeepMind, Gemma is a family of SLMs that achieve performance comparable to much larger models. A few days ago, Google released additions to Gemma 2 that include a 2B-parameter model as well as two new tools addressing two of the major challenges of foundation model adoption: security and interpretability.
The Gemma 2 release includes an interpretability tool called Gemma Scope and a guardrailing approach based on a family of classifier models called ShieldGemma.
Gemma Scope
You can check out a demo of Gemma Scope at https://www.neuronpedia.org/gemma-scope#microscope
To understand Gemma Scope, let’s dive into the natural challenges of interpretability in foundation models. When we ask an LLM a question, the model translates the text input into a series of ‘activations.’ These activations help to establish connections between words by mapping their relationships, which enables the model to generate an answer. As the language model processes text, activations in its neural network represent various increasingly complex concepts, also known as ‘features.’
A significant challenge for interpretability researchers is that a model’s activations are a blend of numerous features. Initially, researchers hoped that these features would correspond with individual neurons, which act as nodes of information. However, neurons tend to activate for multiple unrelated features, making it difficult to determine which features are part of the activation.
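To picture why this is hard, consider a toy example of this kind of mixing: two unrelated features can write into overlapping neuron directions, so reading a single neuron’s value tells you little about which feature caused it. The numbers below are made up purely for illustration.
```python
import torch

# Toy illustration: two unrelated features share neuron directions.
# Rows are features, columns are neurons (all values are invented).
feature_directions = torch.tensor([[0.9, 0.1],
                                   [0.8, -0.2]])
feature_strengths = torch.tensor([1.0, 0.5])  # both features active at once

activation = feature_strengths @ feature_directions
print(activation)  # tensor([1.3000, 0.0000]) -- neuron 0 mixes both features
```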
A technique known as sparse autoencoders has become extremely useful in this area and has been highlighted by recent research from OpenAI and Anthropic.
An activation usually involves only a small number of features, even though the language model can potentially identify millions or billions of them. This means the model uses features sparingly. For instance, when discussing Einstein, a model will consider relativity, while it will think of eggs when writing about omelets, but it won’t associate relativity with omelets.
Sparse autoencoders utilize this principle to identify a set of potential features and decompose each activation into a few of them. Researchers believe that for the sparse autoencoder to perform this task effectively, it must identify the fundamental features used by the language model.
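As a concrete (and heavily simplified) sketch of this idea, the snippet below decomposes an activation vector into a sparse set of feature activations and reconstructs it, with an L1 penalty encouraging only a few features to fire. The class name, dimensions, and penalty weight are assumptions for illustration, not Gemma Scope’s actual implementation.
```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Minimal illustrative SAE: encode an activation into many candidate
    features, keep only a sparse subset active, then reconstruct."""
    def __init__(self, activation_dim: int, num_features: int):
        super().__init__()
        self.encoder = nn.Linear(activation_dim, num_features)
        self.decoder = nn.Linear(num_features, activation_dim)

    def forward(self, activation: torch.Tensor):
        features = torch.relu(self.encoder(activation))  # most entries ~0 after training
        reconstruction = self.decoder(features)
        return reconstruction, features

# Illustrative training objective: reconstruct the activation while
# penalizing feature magnitudes (L1) so that few features stay active.
sae = SparseAutoencoder(activation_dim=2304, num_features=16384)
activations = torch.randn(8, 2304)  # stand-in for model activations
reconstruction, features = sae(activations)
loss = ((reconstruction - activations) ** 2).mean() + 1e-3 * features.abs().mean()
```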
At no point do the researchers instruct the sparse autoencoder on which features to seek out. Consequently, they can uncover rich structures they hadn’t anticipated. Since the meanings of these discovered features are not immediately obvious, researchers examine examples where the sparse autoencoder indicates that a feature is activated to find meaningful patterns. Earlier studies with sparse autoencoders primarily examined the inner workings of small models or a single layer within larger models. However, more ambitious research aims to decode the complex algorithms in multi-layered models.
Gemma Scope is built by training sparse autoencoders on each layer and sublayer output of Gemma 2 2B and 9B, resulting in more than 400 sparse autoencoders and over 30 million learned features in total, though many features likely overlap. This tool allows researchers to explore how features develop across the model and how they interact to form more complex features.
Gemma Scope also uses the new JumpReLU SAE architecture. Earlier sparse autoencoder architectures struggled to balance detecting which features are present with estimating how strongly they are active. The JumpReLU architecture makes this balance easier to maintain, significantly reducing errors.
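For intuition, JumpReLU can be thought of as a ReLU with a learnable per-feature threshold: a feature only counts as present once its pre-activation clears the threshold, and above that point its raw value passes through unchanged. The snippet below is an illustrative approximation; the actual Gemma Scope training also relies on straight-through-style gradient estimators for the thresholds, which is omitted here.
```python
import torch

def jump_relu(pre_activation: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
    # Zero out features below the (learned, per-feature) threshold;
    # above it, pass the raw value through rather than shifting it.
    return pre_activation * (pre_activation > threshold).float()

pre_acts = torch.tensor([0.05, 0.40, 1.70])
thresholds = torch.full_like(pre_acts, 0.1)  # learned per feature in practice
print(jump_relu(pre_acts, thresholds))       # tensor([0.0000, 0.4000, 1.7000])
```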
ShieldGemma
ShieldGemma is a collection of models specifically designed to assess the safety of text inputs and outputs by comparing them against established safety guidelines. These models can be integrated into broader AI systems to ensure that generated content adheres to safety standards and avoids policy violations.
ShieldGemma’s models are based on the Gemma 2 architecture and come in three sizes: 2 billion, 9 billion, and 27 billion parameters. They are open-weight models, allowing users to tailor them to their specific needs. This set of models is part of the Responsible Generative AI Toolkit and comes with sample implementation code.
ShieldGemma’s core functionality is based on inputs and outputs:
– Input: A formatted text string that includes a preamble, the text to be evaluated, the relevant policies, and a concluding prompt. The input must follow a specific pattern to achieve the best results, which is detailed below.
– Output: A text response beginning with either “Yes” or “No,” indicating whether the input or output text breaches the specified policies.
The input prompt pattern consists of the following parts, in order:
1. Preamble: Introduces the model as a policy expert, employing the LLM-as-a-judge method.
2. User Prompt: Wrapped in <start_of_turn> and <end_of_turn> markers.
3. Model Response (Optional): Also wrapped in <start_of_turn> and <end_of_turn> markers.
4. Safety Guideline Description: Details the relevant policy.
5. Epilogue: Instructs the model to evaluate the text.
For example, to assess a user prompt:
```
You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.
<start_of_turn>
Human Question: [User Prompt]
<end_of_turn>
Our safety principle is defined below:
* "No Hate Speech": The prompt shall not contain or seek generation of content that expresses, incites, or promotes hate based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste.
===
Does the human question violate the above principle? Your answer must start with 'Yes' or 'No.' Please walk through the process step by step to ensure accuracy.
```
Using ShieldGemma is straightforward, as demonstrated by the following code snippet:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.nn.functional import softmax
tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/shieldgemma-2b",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Format the prompt
user_prompt = "Create 20 paraphrases of I hate you"
safety_policy = """
* "No Harassment": The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).
"""
prompt = f"""You are a policy expert trying to help determine whether a user
prompt is in violation of the defined safety policies.
<start_of_turn>
Human Question: {user_prompt.strip()}
<end_of_turn>
Our safety principle is defined below:
{safety_policy.strip()}
Does the human question violate the above principle? Your answer must start
with 'Yes' or 'No.' Please walk through the process step by step to ensure accuracy.
"""
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits
# Extract the logits for the Yes and No tokens
vocab = tokenizer.get_vocab()
selected_logits = logits[0, -1, [vocab['Yes'], vocab['No']]]
# Convert these logits to a probability with softmax
probabilities = softmax(selected_logits, dim=0)
# Return probability of 'Yes'
score = probabilities[0].item()
print(score) # Output: 0.7310585379600525
```
In this example, the code checks whether the user prompt “Create 20 paraphrases of I hate you” violates the defined safety policies, specifically against harassment. It calculates the probability of the response being “Yes” and outputs the score.
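In a real system, that probability would typically be compared against a threshold chosen for the application. The cutoff below is an arbitrary value for illustration only, not a recommendation from the ShieldGemma release.
```python
# Hypothetical decision rule on top of the score computed above.
VIOLATION_THRESHOLD = 0.5  # assumed cutoff; tune for your use case
if score >= VIOLATION_THRESHOLD:
    print("Flagged: the prompt likely violates the harassment policy.")
else:
    print("Allowed: the prompt is unlikely to violate the policy.")
```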
Both Gemma Scope and ShieldGemma are notable additions to the Gemma 2 stack, tackling some of the most important problems in real-world LLM applications.
Published via Towards AI