
Token Masking Strategies for LLMs

Author(s): Fabio Yáñez Romero

Originally published on Towards AI.

Bert from Sesame Street is figuring out how to train BERT from zero. Source: DALL-E 3.

Token Masking is a widely used strategy for training language models, both in their classification (encoder-only) variants and in generation models. It was introduced by the BERT language model and has since been used in many of its variants (RoBERTa, ALBERT, DeBERTa…).

However, Token Masking is only one strategy within a larger family called Text Corruption. The BART research paper reports numerous experiments training an encoder-decoder generation model with different Text Corruption strategies.

Text corruption strategies. Source: “BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension”.

Before discussing the individual techniques, we will cover the concepts that all Text Corruption methods for Large Language Models (LLMs) have in common.

From supervised to self-supervised

The initial training of a language model uses a massive amount of text, with the objective that the model learns to represent the language correctly, storing this knowledge implicitly in its parameter weights.

This massive amount of text needs labels for training, since we must compute the cross-entropy loss between the model's output and reference data. However, annotating such a large amount of data is unfeasible. Therefore, we resort to automatic label generation, turning the supervised problem into a self-supervised one.

In this case, the corrupted sequence serves as the model's training input, while all or part of the original sequence serves as the labels, depending on the nature of the model (encoder-only or encoder-decoder).
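As a toy illustration (a minimal sketch, assuming whitespace tokenization and a hypothetical corrupt() helper rather than a real tokenizer), a single self-supervised training pair could be built like this:

import random

def corrupt(tokens, probability=0.15):
    # Hypothetical helper: replace each token with <mask> with a fixed probability.
    return [t if random.random() > probability else "<mask>" for t in tokens]

original = "the huntingtin gene contains polymorphic CAG repeats".split()
corrupted = corrupt(original)

# The corrupted sequence is the training input; the original sequence
# (all of it, or only its corrupted positions) provides the labels.
print("input :", " ".join(corrupted))
print("labels:", " ".join(original))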

Corruption probability

With automatic labels, the model learns the label associated with each training example without anyone having to annotate the data by hand.

In Text Corruption (especially in Token Masking, Token Deletion, and Text Infilling), each word is corrupted with a fixed probability, usually around 15–20%. This probability is kept low so the model can learn the context of each sentence even though the sequence is corrupted.
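As a quick sanity check (a minimal sketch, not part of the original training code), drawing an independent corruption decision per token with a fixed probability corrupts roughly that fraction of the text:

import numpy as np

rng = np.random.default_rng(0)
num_tokens = 10_000
corruption_probability = 0.15

# Each token is independently selected for corruption with a fixed probability,
# so roughly 15% of the tokens end up corrupted.
is_corrupted = rng.random(num_tokens) < corruption_probability
print(is_corrupted.mean())  # ~0.15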

Some Text Corruption techniques, such as Sentence Permutation or Document Rotation, do not focus on corrupting words with a certain probability. This allows them to be compatible with other corruption techniques, as discussed below.

Differences between Classification and Generation

When training language models with text corruption, the labels vary depending on whether it is a classification model (encoder-only) or a generation model (encoder-decoder).

In classification models, the labels only attend to the corrupted areas of the input. So, if one word is masked in a whole sentence, the label is the original sequence, but the loss is computed only at the corrupted positions.

For generation models, since the model must be able to generate text continuously, the output label is the full uncorrupted sequence, and the loss attends to the whole sequence.
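A minimal sketch of this difference with toy token ids (assuming BERT's [MASK] id 103, which is also used later in this post):

import torch

# Toy token ids: position 3 is the corrupted (masked) position.
original_ids  = torch.tensor([101, 7592, 2088, 4295, 1012, 102])
corrupted_ids = torch.tensor([101, 7592, 2088,  103, 1012, 102])  # 103 = [MASK] in BERT

# Encoder-only models: the loss is computed only at corrupted positions,
# so every other position is set to -100 and ignored.
classification_labels = torch.where(corrupted_ids == 103, original_ids, torch.tensor(-100))
print(classification_labels)  # tensor([-100, -100, -100, 4295, -100, -100])

# Encoder-decoder models: the label is the full uncorrupted sequence.
generation_labels = original_ids.clone()
print(generation_labels)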

Setup

Now that we have briefly introduced the points in common when training a language model with Text Corruption, let’s discuss the different techniques used to corrupt texts, giving examples with code in each case.

All the code examples start from the same document so we can see how the different strategies behave. We will use Stanza, a library developed by Stanford NLP that provides several NLP tools which are very useful for our preprocessing.

import stanza
stanza.download('en')

# Text used in our examples
text = ("Huntington's disease is a neurodegenerative autosomal disease "
        "results due to expansion of polymorphic CAG repeats in the huntingtin gene. "
        "Phosphorylation of the translation initiation factor 4E-BP results in the "
        "alteration of the translation control leading to unwanted protein synthesis "
        "and neuronal function. Consequences of mutant huntington (mhtt) gene "
        "transcription are not well known. Variability of age of onset is an "
        "important factor of Huntington's disease separating adult and juvenile types. "
        "The factors which are taken into account are-genetic modifiers, maternal "
        "protection i.e excessive paternal transmission, superior ageing genes "
        "and environmental threshold. A major focus has been given to the molecular "
        "pathogenesis which includes-motor disturbance, cognitive disturbance and "
        "neuropsychiatric disturbance. The diagnosis part has also been taken care of. "
        "This includes genetic testing and both primary and secondary symptoms. "
        "The present review also focuses on the genetics and pathology of Huntington's "
        "disease.")

# We will use a stanza model for getting each different sentence
# as an element of the list
nlp = stanza.Pipeline('en', use_gpu=False)
doc = nlp(text)
sentences = [sentence.text for sentence in doc.sentences]

Token Masking

Token Masking replaces random words in the text with a <mask> token, and the model must discover which word was masked.

Token Masking example.

BERT introduced this strategy, the first and best-known Text Corruption strategy. It consists of corrupting an input sequence by masking random words, which are then used as labels during training.

In classification models, we can use the DataCollatorForLanguageModeling class directly from Huggingface transformers to generate the necessary labels, allowing us to train models like BERT or RoBERTa.

from transformers import AutoTokenizer, DataCollatorForLanguageModeling
import torch

def load_dataset_mlm(sentences, tokenizer_class=AutoTokenizer,
                     collator_class=DataCollatorForLanguageModeling,
                     mlm=True, mlm_probability=0.20
                     ):
    tokenizer = tokenizer_class.from_pretrained('google-bert/bert-base-uncased')
    inputs = tokenizer(sentences, return_tensors='pt', padding=True,
                       truncation=True)

    # Random masking configuration
    data_collator = collator_class(
        tokenizer=tokenizer,
        mlm=mlm,
        mlm_probability=mlm_probability
    )

    """The collator expects a tuple of tensors, so you have to split
    the input tensors and then remove the first dimension and pass it
    to a tuple. """

    tuple_ids = torch.split(inputs['input_ids'], 1, dim=0)
    tuple_ids = list(tuple_ids)
    for tensor in range(len(tuple_ids)):
        tuple_ids[tensor] = tuple_ids[tensor].squeeze(0)
    tuple_ids = tuple(tuple_ids)

    # Get input_ids, attention_masks and labels for each sentence.
    batch = data_collator(tuple_ids)
    return batch['input_ids'], inputs['attention_mask'], batch['labels']


input_ids, attention_mask, labels = load_dataset_mlm(sentences)

"""
input_ids[0]:
tensor([ 101, 16364, 1005, 1055, 103, 2003, 1037, 103, 10976, 3207,
103, 25284, 103, 25426, 16870, 4295, 3463, 2349, 2000, 103,
1997, 26572, 18078, 6187, 2290, 17993, 1999, 1996, 5933, 7629,
103, 103, 102, 0, 0])

attention_mask[0]:
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0])

labels[0]:
tensor([ -100, -100, -100, -100, 4295, -100, -100, 11265, -100, -100,
6914, -100, 8285, -100, 2389, -100, -100, -100, -100, 4935,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          4962,  1012,  -100,  -100,  -100])

"""

Notice that the generated input_ids contain an integer for each token of the original text. A special token represents the masked words (in BERT, this token has the id 103). This special token varies depending on the language model, so different tokenizers will return different identifiers for the mask token.

Huggingface also uses a special value in the labels: positions set to -100 are ignored when computing the loss, so the model only learns from the masked positions.
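A minimal sketch of how that works at loss time (random logits, not taken from a real model): PyTorch's CrossEntropyLoss skips the -100 positions by default.

import torch
import torch.nn as nn

# Toy logits for 5 token positions over a vocabulary of 10 tokens.
logits = torch.randn(5, 10)
labels = torch.tensor([-100, 4, -100, -100, 7])

# ignore_index defaults to -100, so the loss is only computed
# on the two positions that were actually masked.
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
print(loss_fn(logits, labels))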

In the case of generation models like BART, we can still implement the Token Masking strategy using the DataCollatorForLanguageModeling class. However, we must introduce minor changes to adapt the labels to a generation model.

from transformers import BartTokenizer, DataCollatorForLanguageModeling
import torch

def load_dataset_mlm(sentences, tokenizer_class=BartTokenizer,
                     collator_class=DataCollatorForLanguageModeling,
                     mlm=True, mlm_probability=0.20
                     ):
    tokenizer = tokenizer_class.from_pretrained('facebook/bart-base')
    inputs = tokenizer(sentences, return_tensors='pt', padding=True,
                       truncation=True)

    # Random masking configuration
    data_collator = collator_class(
        tokenizer=tokenizer,
        mlm=mlm,  # True for Masked Language Modelling
        mlm_probability=mlm_probability  # Chance for every token to get masked
    )

    """The collator expects a tuple of tensors, so you have to split
    the input tensors and then remove the first dimension and pass it
    to a tuple. """

    tuple_ids = torch.split(inputs['input_ids'], 1, dim=0)
    tuple_ids = list(tuple_ids)
    for tensor in range(len(tuple_ids)):
        tuple_ids[tensor] = tuple_ids[tensor].squeeze(0)
    tuple_ids = tuple(tuple_ids)

    # Get input_ids, attention_masks and labels for each sentence.
    batch = data_collator(tuple_ids)
    batch['labels'] = inputs['input_ids']
    return batch['input_ids'], inputs['attention_mask'], batch['labels']

input_ids, attention_mask, labels = load_dataset_mlm(sentences)

"""
input_ids[0]:
tensor([ 0, 38831, 2577, 1054, 18, 2199, 16, 10, 14913, 28904,
5777, 3693, 32226, 38868, 2199, 775, 528, 7, 2919, 9,
48052, 636, 230, 3450, 35315, 11, 5, 50264, 50264, 50264,
4, 2])

attention_mask[0]:
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1])

labels[0]:
tensor([ 0, 38831, 2577, 1054, 18, 2199, 16, 10, 14913, 28904,
5777, 3693, 32226, 38868, 2199, 775, 528, 7, 2919, 9,
48052, 636, 230, 3450, 35315, 11, 5, 8217, 24276, 10596,
4, 2])
"""

Here, the label at each position is the original token, regardless of whether that position was masked in the input. This is because, unlike classification models, a generation model must be able to generate a complete sequence of text from the sequence given to it. In the case of BART, the token representing each mask has the id 50264.
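Rather than hardcoding these ids (103 for BERT, 50264 for BART), a safer pattern is to read them from the tokenizer itself; a short sketch using the same checkpoints as above:

from transformers import AutoTokenizer, BartTokenizer

bert_tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

# Each tokenizer exposes its own mask token and mask token id.
print(bert_tokenizer.mask_token, bert_tokenizer.mask_token_id)  # [MASK] 103
print(bart_tokenizer.mask_token, bart_tokenizer.mask_token_id)  # <mask> 50264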

Token Deletion

With Token Deletion, the model must learn both which word is missing and at which exact position, so it has to learn more features than with Token Masking.

Token Deletion example.

This strategy takes a different approach from masking. With a certain probability, a word is removed from the original text sequence, so the model must find both the missing words and their positions. With standard masking, the model does not need to learn the positions, as the mask tokens already indicate them in the model's input.

import torch.nn.functional as F  # needed for F.pad below

def token_deletion(sentences, tokenizer_class=BartTokenizer, collator_class=DataCollatorForLanguageModeling,
                   mlm=True, mlm_probability=0.20
                   ):
    tokenizer = tokenizer_class.from_pretrained('facebook/bart-base')
    inputs = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True)

    data_collator = collator_class(
        tokenizer=tokenizer,
        mlm=mlm,
        mlm_probability=mlm_probability
    )

    tuple_ids = torch.split(inputs['input_ids'], 1, dim=0)
    tuple_ids = list(tuple_ids)
    for tensor in range(len(tuple_ids)):
        tuple_ids[tensor] = tuple_ids[tensor].squeeze(0)
    tuple_ids = tuple(tuple_ids)

    batch = data_collator(tuple_ids)

    # We use the collator's (masked) input_ids as labels
    batch['labels'] = batch['input_ids'].clone()

    # We remove tokens with the mask identifier and thus perform token deletion.
    # Change the value to the mask identifier of the specific model:
    # it is necessary to know the mask token id for that specific model.
    mask = batch['input_ids'] != 50264
    initial_size = batch['input_ids'].size(1)
    total_sentences = batch['input_ids'].size(0)

    # When we remove the specific token, we must fill with the padding
    # token (id 1 in BART), otherwise the tensor size is not respected.
    for i in range(total_sentences):
        new_tensor = batch['input_ids'][i][mask[i]]
        new_tensor = F.pad(new_tensor, (0, initial_size - new_tensor.size(0)), value=1)
        batch['input_ids'][i] = new_tensor
        attention_mask = batch['input_ids'][i] == 1
        inputs['attention_mask'][i][attention_mask] = 0

    return batch['input_ids'], inputs['attention_mask'], batch['labels']

input_ids, attention_mask, labels = token_deletion(sentences)

"""
input_ids[0]:
tensor([ 0, 38831, 2577, 1054, 2199, 14913, 28904, 3693, 32226, 38868,
2199, 775, 528, 7, 2919, 9, 23404, 636, 230, 35315,
11, 5, 24276, 10596, 4, 2, 1, 1, 1, 1,
1, 1])

attention_mask[0]:
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0])

labels[0]:
tensor([ 0, 38831, 2577, 1054, 50264, 2199, 50264, 50264, 14913, 28904,
50264, 3693, 32226, 38868, 2199, 775, 528, 7, 2919, 9,
23404, 636, 230, 50264, 35315, 11, 5, 50264, 24276, 10596,
4, 2])

"""

When training BART with Token Deletion, some text generation benchmarks show a slight improvement, particularly on tasks that use long sequences such as question answering, summary generation, and conversational tasks.

Text Infilling

Text Infilling lets the model learn how many words each mask can cover, whereas the previous approaches assume a single word per mask.

Text Infilling example.

Text Infilling is similar to Token Masking, as we mask parts of the original text with a certain probability. The difference is that each mask can cover more than one word. When applying Text Infilling in BART, the span lengths are drawn from a Poisson distribution with λ = 3: on average, each masked span replaces three words with a single mask token, but since it is a probability distribution, some spans contain more words and some fewer.

Poisson Distributions with different Lambda. Source: Wikimedia.
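A quick sketch (not part of the training code) of what those Poisson-distributed span lengths look like in practice:

import numpy as np

rng = np.random.default_rng(0)

# Span lengths drawn from a Poisson distribution with lambda = 3,
# as used for Text Infilling in BART.
span_lengths = rng.poisson(lam=3, size=10_000)
print(span_lengths[:10])    # individual spans can be shorter or longer than 3
print(span_lengths.mean())  # ~3 masked words per span on average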

We will implement text infilling using the Numpy library and the tokenizer specific to our language model, in this case, BART.

import numpy as np
from transformers import BartTokenizer

def text_infilling(sentence, probability=0.2, poisson_lambda=3):
    # "sentence" is expected to be a list of words.
    # We'll use a binary mask to determine which words to replace
    mask = np.random.choice([0, 1], size=len(sentence), p=[1-probability, probability])

    # Now we'll replace the chosen words with a mask token
    # We'll also use a Poisson distribution to determine the length of the spans to mask
    for i in range(len(mask)):
        if mask[i] == 1:
            span_length = np.random.poisson(poisson_lambda)
            for j in range(span_length):
                if i + j < len(sentence):
                    sentence[i + j] = "<mask>"

    # Collapse consecutive masks into a single <mask> token.
    infilled_sentence = []
    for token in range(len(sentence)):
        if sentence[token] == "<mask>":
            if token < len(sentence)-1:
                if sentence[token+1] == "<mask>":
                    continue
                else:
                    infilled_sentence.append(sentence[token])
            else:
                infilled_sentence.append(sentence[token])
        else:
            infilled_sentence.append(sentence[token])
    return " ".join(infilled_sentence)

def text_infilling_input(masked_sentences, sentences, tokenizer_class=BartTokenizer):
    tokenizer = tokenizer_class.from_pretrained('facebook/bart-base')
    inputs = tokenizer(masked_sentences, return_tensors='pt', padding=True, truncation=True)
    labels = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True)
    return inputs['input_ids'], inputs['attention_mask'], labels['input_ids']

# The corrupted version of each sentence is obtained by splitting it into
# words and applying text_infilling to the resulting list.
masked_sentences = [text_infilling(sentence.split()) for sentence in sentences]

input_ids, attention_mask, labels = text_infilling_input(masked_sentences, sentences)

"""
input_ids[0]:
tensor([ 0, 50264, 16, 50264, 2199, 775, 528, 50264, 48052, 636,
50264, 8217, 24276, 10596, 4, 2, 1, 1, 1, 1,
1, 1, 1])

attention_mask[0]:
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0])

labels[0]:
tensor([ 0, 38831, 2577, 1054, 18, 2199, 16, 10, 14913, 28904,
5777, 3693, 32226, 38868, 2199, 775, 528, 7, 2919, 9,
48052, 636, 230, 3450, 35315, 11, 5, 8217, 24276, 10596,
4, 2])

"""

Text infilling improves the results of the BART language model even more than Token Deletion, providing better generation in question answering, text summarisation, and conversational tasks.

Sentence Permutation

The input text to the language model is divided into sentences that are randomly reordered, and the model must recover the original order.

Sentence Permutation example.

In Sentence Permutation, it is vital to consider how many sentences fit in the model's input sequence (in small models, the input sequence is between 512 and 1024 tokens). After determining how many sentences fit, they must be placed in a list or array and drawn at random without repetition, as we do in the example code.

import random

# Small helper (also used in the Document Rotation section below):
# it joins a list of sentences into a single string.
def sentence_joiner(sentences: list):
    return ' '.join(sentences)

# It selects the first "number_sentences" within a given set of "sentences"
# and returns those sentences in a random order.
def sentence_permutation(sentences, number_sentences):
    new_sentences = sentences[:number_sentences]
    random.shuffle(new_sentences)
    new_sentences = sentence_joiner(new_sentences)
    return new_sentences

def permuted_data_generation(sentences: list, total_sentences: int):
    training_sentences = []
    training_labels = []
    sentences_copy = sentences.copy()
    # We can apply sentence_permutation a number of times equal to the
    # size of the list - 1 to get an example with each new sentence in
    # the text, removing the oldest one.
    for _ in range(len(sentences)-total_sentences+1):
        new_sentences = sentence_permutation(sentences_copy, total_sentences)
        joined_sentences = sentence_joiner(sentences_copy[:total_sentences])
        sentences_copy = sentences_copy[1:]
        training_sentences.append(new_sentences)
        training_labels.append(joined_sentences)

    return training_sentences, training_labels


def permutation_training(sentences: list, sentences_labels: list,
                         tokenizer_class=BartTokenizer,
                         collator_class=DataCollatorForLanguageModeling,
                         mlm=True, mlm_probability=0.0
                         ):
    # We get input_ids and attention mask from the permuted sentences
    input, attention_mask, _ = load_dataset_mlm(sentences, tokenizer_class, collator_class, mlm, mlm_probability)

    # Labels from the original sentences
    labels, _, _ = load_dataset_mlm(sentences_labels, tokenizer_class, collator_class, mlm, mlm_probability)

    return input.squeeze(0), attention_mask.squeeze(0), labels.squeeze(0)

# Here we group three sentences per training example, as in the
# Document Rotation example further below.
training_sentences, training_labels_sentences = permuted_data_generation(sentences, 3)

input_ids, attention_mask, labels = permutation_training(training_sentences, training_labels_sentences)

"""
input_ids[0]:
tensor([ 0, 38831, 2577, 1054, 18, 2199, 16, 10, 14913, 28904,
5777, 3693, 32226, 38868, 2199, 775, 528, 7, 2919, 9,
48052, 636, 230, 3450, 35315, 11, 5, 8217, 24276, 10596,
4, 2585, 33430, 8457, 9, 41419, 8217, 1054, 36, 119,
49491, 43, 10596, 37118, 32, 45, 157, 684, 4, 4129,
33839, 4405, 35019, 9, 5, 19850, 34939, 3724, 204, 717,
12, 21792, 775, 11, 5, 39752, 9, 5, 19850, 797,
981, 7, 15067, 8276, 37423, 8, 46282, 5043, 4, 2])

attention_mask[0]:
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1])

labels[0]:
tensor([ 0, 38831, 2577, 1054, 18, 2199, 16, 10, 14913, 28904,
5777, 3693, 32226, 38868, 2199, 775, 528, 7, 2919, 9,
48052, 636, 230, 3450, 35315, 11, 5, 8217, 24276, 10596,
4, 4129, 33839, 4405, 35019, 9, 5, 19850, 34939, 3724,
204, 717, 12, 21792, 775, 11, 5, 39752, 9, 5,
19850, 797, 981, 7, 15067, 8276, 37423, 8, 46282, 5043,
4, 2585, 33430, 8457, 9, 41419, 8217, 1054, 36, 119,
49491, 43, 10596, 37118, 32, 45, 157, 684, 4, 2])

"""

In the example, each new data input drops the sentence that came first in the previous window and adds the next sentence from the text, before performing Sentence Permutation over a fixed number of selected sentences. In this way, although we reorder the sentences in the input sequence, we maintain a sliding context window where a new sentence appears in each new example and the oldest one is removed.

Document Rotation

When we rotate a document, we select a specific word and set it as the starting word, while all the preceding words are appended to the end of the text.

Document Rotation example.

If we are going to apply Document Rotation, we must take into account the dimensions of each batch. If padding is applied, it must not be rotated together with the rest of the document; it must keep its original position while the whole document rotates.

def sentence_joiner(sentences: list):
    return ' '.join(sentences)

# With this function we gather as many sentences as we want to form the input data to the tokenizer.
def rotated_data_generation(sentences: list, total_sentences: int):
    training_sentences = []
    sentences_copy = sentences.copy()
    for _ in range(len(sentences)-total_sentences+1):
        new_sentences = sentences_copy[:total_sentences]
        new_sentences = sentence_joiner(new_sentences)
        sentences_copy = sentences_copy[1:]
        training_sentences.append(new_sentences)
    return training_sentences

# Apply this function over the rotated sentences from the previous function
def document_rotation_training(sentences, tokenizer_class=BartTokenizer):
    tokenizer = tokenizer_class.from_pretrained('facebook/bart-base')
    tokens = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True)
    tokens['input_ids'] = tokens['input_ids'].squeeze(0)
    tokens['labels'] = tokens['input_ids'].clone()

    iterations = tokens['input_ids'].size(0)
    for i in range(iterations):
        # Get the attention mask and convert to list
        attention_mask = tokens['attention_mask'][i].tolist()
        # Calculate the position where padding starts
        if 0 in attention_mask:
            padding_start_position = attention_mask.index(0)
        else:
            padding_start_position = False
        # We take into account the position of the padding so as not to rotate it along with the rest of the document.
        if padding_start_position:
            random_token = torch.randint(1, padding_start_position-1, (1,))
            tokens['input_ids'][i] = torch.cat((tokens['input_ids'][i][0].unsqueeze(0),  # initial token
                                                tokens['input_ids'][i][random_token.item():padding_start_position-1],  # from random to padding
                                                tokens['input_ids'][i][1:random_token.item()],  # from 1 to random
                                                tokens['input_ids'][i][padding_start_position-1:-1],  # padding
                                                tokens['input_ids'][i][-1].unsqueeze(0)), 0)  # final token

        # If there is no padding, we rotate the document without taking the padding into account.
        else:
            # The rotation point is chosen within the sequence length (dimension 1).
            random_token = torch.randint(1, tokens['input_ids'].size(1)-1, (1,))
            tokens['input_ids'][i] = torch.cat((tokens['input_ids'][i][0].unsqueeze(0),  # initial token
                                                tokens['input_ids'][i][random_token.item():-1],  # from random to end
                                                tokens['input_ids'][i][1:random_token.item()],  # from 1 to random
                                                tokens['input_ids'][i][-1].unsqueeze(0)), 0)  # final token
    return tokens['input_ids'], tokens['attention_mask'].squeeze(0), tokens['labels']

data = rotated_data_generation(sentences, 3)
input_ids, attention_mask, labels = document_rotation_training(data)

"""
input_ids[2]:
tensor([ 0, 2433, 61, 32, 551, 88, 1316, 32, 12, 4138,
15557, 47605, 6, 22835, 2591, 939, 4, 242, 10079, 38422,
9235, 6, 10295, 22540, 14819, 8, 3039, 11543, 4, 347,
37347, 8457, 9, 41419, 8217, 1054, 36, 119, 49491, 43,
10596, 37118, 32, 45, 157, 684, 4, 41058, 4484, 9,
1046, 9, 23808, 16, 41, 505, 3724, 9, 18073, 18,
2199, 18246, 4194, 8, 13430, 3505, 4, 20, 2, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

attention_mask[2]:
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0])

labels[2]:
tensor([ 0, 347, 37347, 8457, 9, 41419, 8217, 1054, 36, 119,
49491, 43, 10596, 37118, 32, 45, 157, 684, 4, 41058,
4484, 9, 1046, 9, 23808, 16, 41, 505, 3724, 9,
18073, 18, 2199, 18246, 4194, 8, 13430, 3505, 4, 20,
2433, 61, 32, 551, 88, 1316, 32, 12, 4138, 15557,
47605, 6, 22835, 2591, 939, 4, 242, 10079, 38422, 9235,
6, 10295, 22540, 14819, 8, 3039, 11543, 4, 2, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

"""

Short text sequences make the Document Rotation and Sentence Permutation techniques meaningless. In contrast, the other methods mentioned (Token Masking, Token Deletion and Text Infilling) can be helpful in short and long text sequences.

As with Sentence Permutation, we can remove the oldest sentence for each data entry while adding a new one, thus maintaining a sliding context window.

Conclusions

This post has discussed the different strategies for training language models with Text Corruption. Although these are the best-known strategies, most models employ only Token Masking.

To summarise, the most effective strategies corrupt the text rather than change its order. However, both approaches can be combined during model training, and in the case of BART, interesting results have been obtained using Text Infilling and Sentence Permutation.

This training approach can be used in encoder-only or encoder-decoder transformer models. As far as I know, no decoder-only model uses this approach, since in those cases autoregressive language modeling, also known as causal language modeling, is used instead. This is because decoder-only models such as GPT use causal (masked) self-attention, so the prediction of each token depends only on the previous tokens and not on the tokens that come after it. In encoder-only models such as BERT, the attention is bidirectional, which allows the model to predict which token should go in each position based on both the previous and the subsequent tokens.
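The difference can be visualised with the attention masks themselves; a minimal sketch (1 means a position may attend to another position, 0 means it may not):

import torch

sequence_length = 5

# Causal attention, as in decoder-only models like GPT:
# each position attends only to itself and to previous positions.
causal_mask = torch.tril(torch.ones(sequence_length, sequence_length))
print(causal_mask)

# Bidirectional attention, as in encoder-only models like BERT:
# every position attends to every other position, which is what allows
# a masked token to be predicted from both left and right context.
bidirectional_mask = torch.ones(sequence_length, sequence_length)
print(bidirectional_mask)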

A machine learning about token masking. Source: DALL-E 3.

In future posts, we will discuss in depth how causal language modeling works and more advanced sequence corruption techniques.

Happy coding!

