了解BigBird的块稀疏注意力

BigBird的块稀疏注意力

介绍

基于Transformer的模型在许多自然语言处理任务中表现出了极高的效用。然而,基于Transformer的模型的一个主要限制是其O(n^2)的时间和内存复杂度(其中n是序列长度)。因此,将基于Transformer的模型应用于长度超过512的长序列是非常耗费计算资源的。近期的一些论文,如LongformerPerformerReformerClustered attention,尝试通过近似全连接注意力矩阵来解决这个问题。如果您对这些模型不熟悉,可以查看🤗的最新博客文章。

BigBird(在论文中提出)就是解决这个问题的最近的一种模型之一。与BERT的注意力机制不同,BigBird采用了分块稀疏注意力,可以以较低的计算成本处理长度为4096的序列,相比之下,BERT的计算成本要高得多。它在涉及到非常长的序列的各种任务中,如长文档摘要、长上下文的问答等,已经达到了最先进水平。

现在,🤗Transformers中提供了BigBird RoBERTa-like模型。本文的目标是让读者深入了解BigBird的实现,并帮助使用🤗Transformers轻松使用BigBird。但是,在深入了解之前,重要的一点是要记住,BigBird的注意力是对BERT的全连接注意力的一种近似,因此并不致力于比BERT的全连接注意力更好,而是更高效。它简单地允许将基于Transformer的模型应用于更长的序列,因为BERT的二次内存要求很快变得无法承受。简而言之,如果我们有无穷的计算和无穷的时间,BERT的注意力机制将优于分块稀疏注意力(我们将在本文中讨论)。

如果您想知道为什么在处理更长的序列时我们需要更多的计算资源,这篇博客文章正适合您!


在使用标准的BERT-like注意力机制时,人们可能会有一些主要问题:

  • 所有标记是否真的需要与所有其他标记关联?
  • 为什么不仅计算与重要标记的关联性?
  • 如何确定哪些标记是重要的?
  • 如何以非常高效的方式关注只有少数标记?

在本博客文章中,我们将尝试回答这些问题。

应该关注哪些标记?

我们将通过一个实际的例子来说明注意力机制的工作原理,考虑句子“BigBird现在可以在HuggingFace中用于抽取式问答”。在类似BERT的注意力机制中,每个词汇只会关注所有其他标记。数学上来说,这意味着每个查询标记query-token ∈ { BigBird , is , now , available , in , HuggingFace , for , extractive , question , answering },都会关注完整的键标记列表key-tokens = [ BigBird , is , now , available , in , HuggingFace , for , extractive , question , answering ]。

让我们通过编写一些伪代码来思考一个合理的键令牌选择,即查询令牌实际上只应该参与。我们将假设查询的令牌available并构建一个合理的键令牌列表来参与。

>>> # 让我们考虑以下句子作为示例
>>> example = ['BigBird', 'is', 'now', 'available', 'in', 'HuggingFace', 'for', 'extractive', 'question', 'answering']

>>> # 进一步假设我们正在尝试理解“available”的表示,即
>>> query_token = 'available'

>>> # 我们将初始化一个空的“set”并随着我们在本节中的进展填充我们感兴趣的令牌。
>>> key_tokens = [] # => 当前'available'令牌没有要参与的内容

附近的令牌应该很重要,因为在一个句子(单词序列)中,当前单词高度依赖于相邻的过去和未来的令牌。这个直觉是“滑动注意力”的概念背后的想法。

>>> # 考虑`window_size = 3`,我们将考虑'available'左边1个令牌和右边1个令牌
>>> # 左边令牌:'now' ; 右边令牌:'in'
>>> sliding_tokens = ["now", "available", "in"]

>>> # 让我们使用上面的令牌更新我们的集合
>>> key_tokens.append(sliding_tokens)

长程依赖:对于某些任务来说,捕捉令牌之间的长程关系至关重要。例如,在“问答”任务中,模型需要将上下文的每个令牌与整个问题进行比较,以确定哪个部分的上下文对于正确答案有用。如果大部分上下文令牌只参与其他上下文令牌,而不是问题令牌,模型就很难从不重要的上下文令牌中过滤出重要的上下文令牌。

现在,BigBird提出了两种方式,可以在保持计算效率的同时允许长期注意依赖。

  • 全局令牌:引入一些令牌,这些令牌将参与到每个令牌,并且每个令牌都会参与到这些令牌。例如:”HuggingFace正在为易于处理NLP构建出色的库”。现在,假设“building”被定义为一个全局令牌,并且模型需要了解“NLP”和“HuggingFace”之间的关系以完成某个任务(注意:这两个令牌位于两个极端)。现在,让“building”全局参与到所有其他令牌可能有助于模型将“NLP”与“HuggingFace”关联起来。
>>> # 假设第一个和最后一个令牌是`global`,那么
>>> global_tokens = ["BigBird", "answering"]

>>> # 将全局令牌填充到我们的键令牌集合中
>>> key_tokens.append(global_tokens)
  • 随机令牌:随机选择一些令牌,这些令牌通过传递给其他令牌来传递信息,从而减少了从一个令牌到另一个令牌的信息传递成本。
>>> # 现在我们可以从我们的示例句子中随机选择`r`个令牌
>>> # 假设选择'is',假设`r=1`
>>> random_tokens = ["is"] # 注意:它是完全随机选择的,所以它也可以是其他什么。

>>> # 将随机令牌填充到我们的集合中
>>> key_tokens.append(random_tokens)

>>> # 是时候看看我们的`key_tokens`列表中有哪些令牌了
>>> key_tokens
{'now', 'is', 'in', 'answering', 'available', 'BigBird'}

# 现在,“available”(我们在第一步中选择的查询)只会参与这些令牌,而不是参与整个序列

这样,查询令牌只参与到所有可能令牌的一个子集中,同时产生了完全关注的很好近似。对于所有其他查询的令牌,都将使用相同的方法。但请记住,这里的整体目标是尽可能有效地近似BERT的完全关注。对于现代硬件(如GPU),像BERT一样使每个查询的令牌参与到所有键令牌的计算非常有效,可以通过一系列矩阵乘法来完成。然而,滑动、全局和随机注意力的组合似乎意味着稀疏矩阵乘法,这在现代硬件上的高效实现更加困难。 BigBird的主要贡献之一是提出了一种允许高效计算滑动、全局和随机注意力的块稀疏注意力机制。让我们来看一下吧!

理解图中全局、滑动、随机键的需求

首先,让我们通过图形更好地理解全局 (global)滑动 (sliding)随机 (random) 注意力,并试着理解这三种注意力机制的组合如何产生对标准Bert-like 注意力的很好近似。

上图分别显示了 全局 (global)(左)、滑动 (sliding)(中)和随机 (random)(右)连接作为图形。每个节点对应一个令牌,每条线表示一个注意力分数。如果两个令牌之间没有连接,那么假设注意力分数为0。

BigBird块稀疏注意力是滑动、全局和随机连接(总共10个连接)的组合,如左侧的gif所示。而普通注意力的图形(右侧)将具有所有15个连接(注意:总共有6个节点)。您可以将普通注意力简单地理解为所有令牌都以全局方式参与注意力 1 {}^1 1 。

普通注意力:模型可以在单个层中直接从一个令牌传递信息到另一个令牌,因为每个令牌都会查询其他每个令牌,并且每个令牌都会被其他每个令牌注意到。让我们考虑一个类似于上图所示的示例。如果模型需要将“going”与“now”关联起来,它可以简单地在一个单独的层中完成,因为两个令牌之间有直接的连接。

块稀疏注意力:如果模型需要在两个节点(或令牌)之间共享信息,则对于某些令牌,信息必须在路径中经过其他各个节点传播;因为不是所有节点都在单个层中直接连接。例如,假设模型需要将“going”与“now”关联起来,那么如果只有滑动注意力存在,则这两个令牌之间的信息流动由路径定义:going -> am -> i -> now (即它必须经过2个其他令牌)。因此,我们可能需要多个层来捕捉序列的全部信息。普通注意力可以在一个单独的层中捕捉到这一点。在极端情况下,这可能意味着需要与输入令牌一样多的层。然而,如果我们引入一些全局令牌,信息可以通过路径传播:going -> i -> now (这条路径更短)。如果我们还引入随机连接,信息可以通过路径传播:going -> am -> now 。借助随机连接和全局连接的帮助,信息可以在令牌之间非常快速地传递(只需几个层)。

如果有许多全局令牌,那么我们可能不需要随机连接,因为信息可以通过多个短路径传播。这就是在使用BigBird的一种变体ETC时保持num_random_tokens = 0的想法(稍后的部分将更详细介绍)。

1 {}^1 1 在这些图形中,我们假设注意力矩阵是对称的,即 A i j = A j i \mathbf{A}_{ij} = \mathbf{A}_{ji} A i j ​ = A j i ​ ,因为在图中如果某个令牌 A 参与注意力 B ,那么 B 也会参与 A 。您可以从下一节中显示的注意力矩阵图中看到,这个假设对于BigBird中的大多数令牌成立。

original_full表示BERT的注意力,而block_sparse表示BigBird的注意力。想知道block_size是什么吗?我们将在后面的部分中介绍。现在,为了简单起见,将其视为1。

BigBird块稀疏注意力

BigBird块稀疏注意力只是我们上面讨论的高效实现。每个标记不再关注所有其他标记,而是关注一些全局标记滑动标记随机标记。作者为多个查询组件分别硬编码了注意力矩阵,并使用了一种聪明的技巧来加速GPU和TPU上的训练/推断。

注意:在顶部,我们有两个额外的句子。正如您可以注意到的那样,每个标记在两个句子中只是相互交换了一个位置。这就是滑动注意力的实现方式。当将q[i]k[i,0:3]相乘时,我们将得到q[i]的滑动注意力分数(其中i是序列中元素的索引)。

您可以在此处找到block_sparse注意力的实际实现。现在这可能看起来非常吓人😨😨。但是这篇文章肯定会帮助您更好地理解代码。

全局注意力

对于全局注意力,每个查询只是简单地关注序列中的所有其他标记,并且被每个其他标记所关注。让我们假设Vasudev(第一个标记)和them(最后一个标记)是全局的(在上图中)。您可以看到这些标记直接连接到所有其他标记(蓝色框)。

# 伪代码

Q -> 查询矩阵(seq_length,head_dim)
K -> 键矩阵(seq_length,head_dim)

# 第一个和最后一个标记关注所有其他标记
Q[0] x [K[0],K[1],K[2],......,K[n-1]]
Q[n-1] x [K[0],K[1],K[2],......,K[n-1]]

# 第一个和最后一个标记被所有其他标记关注
K[0] x [Q[0],Q[1],Q[2],......,Q[n-1]]
K[n-1] x [Q[0],Q[1],Q[2],......,Q[n-1]]

滑动注意力

将关键标记序列复制两次,其中一个副本中的每个元素向右移动一个位置,另一个副本中的每个元素向左移动一个位置。现在,如果我们将查询序列向量乘以这3个序列向量,我们将覆盖所有滑动标记。计算复杂度仅为O(3xn) = O(n)。参考上图,橙色框表示滑动注意力。您可以看到图上方有3个序列,其中2个序列向右移动一个标记(一个向左,一个向右)。

# 我们想要做的事情
Q[i] x [K[i-1],K[i],K[i+1]],其中i = 1:-1

# 代码中的高效实现(假设点积乘法 👇)
[Q[0],Q[1],Q[2],......,Q[n-2],Q[n-1]] x [K[1],K[2],K[3],......,K[n-1],K[0]]
[Q[0],Q[1],Q[2],......,Q[n-1]] x [K[n-1],K[0],K[1],......,K[n-2]]
[Q[0],Q[1],Q[2],......,Q[n-1]] x [K[0],K[1],K[2],......,K[n-1]]

# 每个序列只与3个序列相乘,以保持`window_size = 3`。
# 一些计算可能会缺失,这只是一个大致的想法。

随机注意力

随机注意力确保每个查询标记也会关注几个随机标记。对于实际实现,这意味着模型随机收集一些标记,并计算它们的注意力分数。

# r1, r2, r 是一些随机索引;注意:r1, r2, r3对于每一行都不同 👇
Q[1] x [Q[r1],Q[r2],......,Q[r]]
.
.
.
Q[n-2] x [Q[r1],Q[r2],......,Q[r]]

# 不考虑第0个和第(n-1)个标记,因为它们已经是全局的

注意:当前的实现将序列进一步分成块,每个符号都是相对于块而不是标记进行定义的。让我们在下一节中详细讨论这个问题。

实现

回顾:在常规的BERT注意力中,一个标记序列,即 X = x 1 , x 2 , . . . . , x n,通过一个密集层投影到 Q , K , V 和注意力得分 Z 的计算为 Z = S o f t m a x ( Q K T )。对于BigBird块稀疏注意力,使用相同的算法,但仅使用一些选定的查询和键向量。

让我们来看看BigBird块稀疏注意力是如何实现的。首先,假设 b , r , s , g 分别表示 block_size , num_random_blocks , num_sliding_blocks , num_global_blocks。在可视化上,我们可以用 b = 4 , r = 1 , g = 2 , s = 3 , d = 5 来表示BigBird块稀疏注意力的组件,如下图所示:

分别计算 q 1 , q 2 , q 3 的注意力得分:n − 2 , q n − 1 , q n {q}_{1}, {q}_{2}, {q}_{3:n-2}, {q}_{n-1}, {q}_{n},具体计算如下:


表示 q 1 的注意力得分 a 1 a_1 a 1 ​​,其中 a 1 = S o f t m a x ( q 1 ∗ K T ),实际上就是第一个块中所有标记与序列中所有其他标记之间的注意力得分。

q 1 表示第一个块,g i 表示第 i 个块。我们只是在 q 1 和 g 之间执行常规的注意力操作(即对所有键执行操作)。


为了计算第二个块中标记的注意力得分,我们收集了前三个块、最后一个块和第五个块。然后我们可以计算 a 2 = S o f t m a x ( q 2 ∗ c o n c a t ( k 1 , k 2 , k 3 , k 5 , k 7 )。

我用 g , r , s 来表示标记,以明确表示它们的性质(即显示全局、随机、滑动的标记),否则它们只是 k。


为了计算 q 3 : n − 2 的注意力得分,我们将收集全局、滑动、随机的键,并在 q 3 : n − 2 和收集到的键之间执行常规的注意力操作。注意,滑动键使用之前在滑动注意力部分讨论过的特殊移动技巧进行收集。


计算前两个块(即 q n − 1 {q}_{n-1} q n − 1 ​ )中令牌的注意力分数时,我们收集第一个块、最后三个块和第三个块。然后我们可以应用公式 a n − 1 = S o f t m a x ( q n − 1 ∗ c o n c a t ( k 1 , k 3 , k 5 , k 6 , k 7 ) ) {a}_{n-1} = Softmax({q}_{n-1} * concat(k_1, k_3, k_5, k_6, k_7)) a n − 1 ​ = S o f t m a x ( q n − 1 ​ ∗ c o n c a t ( k 1 ​ , k 3 ​ , k 5 ​ , k 6 ​ , k 7 ​ ) ) 。这与我们为 q 2 q_2 q 2 ​ 所做的非常相似。


q n \mathbf{q}_{n} q n ​ 的注意力分数由 a n a_n a n ​ 表示,其中 a n = S o f t m a x ( q n ∗ K T ) a_n=Softmax(q_n * K^T) a n ​ = S o f t m a x ( q n ​ ∗ K T ) ,实际上就是最后一个块中所有令牌与序列中其他所有令牌之间的注意力分数。这与我们为 q 1 q_1 q 1 ​ 所做的非常相似。


让我们将上述矩阵合并以获得最终的注意力矩阵。这个注意力矩阵可以用来获得所有令牌的表示。

蓝色 -> 全局块红色 -> 随机块橙色 -> 滑动块 这个注意力矩阵只是用于说明。在前向传播过程中,我们不存储 白色 块,而是直接为每个分离的组件计算加权值矩阵(即每个令牌的表示),如上所述。

现在,我们已经介绍了块稀疏注意力的最困难部分,即其实现。希望您现在对理解实际代码有更好的背景。请随意深入研究,并将代码的每个部分与上述的组件连接起来。

时间和内存复杂度

比较 BERT 注意力和 BigBird 块稀疏注意力的时间和空间复杂度。

点击以展开此段代码以查看计算

BigBird 时间复杂度 = O(w x n + r x n + g x n)
BERT 时间复杂度 = O(n^2)

假设:
    w = 3 x 64
    r = 3 x 64
    g = 2 x 64

当 seqlen = 512
=> **BERT 时间复杂度 = 512^2**

当 seqlen = 1024
=> BERT 时间复杂度 = (2 x 512)^2
=> **BERT 时间复杂度 = 4 x 512^2**

=> BigBird 时间复杂度 = (8 x 64) x (2 x 512)
=> **BigBird 时间复杂度 = 2 x 512^2**

当 seqlen = 4096
=> BERT 时间复杂度 = (8 x 512)^2
=> **BERT 时间复杂度 = 64 x 512^2**

=> BigBird 计算复杂度 = (8 x 64) x (8 x 512)
=> BigBird 计算复杂度 = 8 x (512 x 512)
=> **BigBird 时间复杂度 = 8 x 512^2**

ITC vs ETC

BigBird 模型可以使用两种不同的策略进行训练:ITC(内部变压器构建)和 ETC(扩展变压器构建)。ITC 简单地就是我们上面讨论的内容。在 ETC(扩展变压器构建)中,一些额外的令牌被设为全局令牌,因此它们将关注/被其他所有令牌关注。

ITC需要更少的计算,因为全局令牌很少,同时模型可以捕捉足够的全局信息(还可以利用随机注意力的帮助)。另一方面,ETC对于需要大量全局令牌的任务非常有帮助,例如“问答”,对于这种任务,整个问题应该在上下文中以全局方式出现,才能正确地将上下文与问题相关联。

注意:在大鸟论文中显示,在许多ETC实验中,随机块的数量设置为0是合理的。这是基于我们在图形部分上面的讨论。

下表总结了ITC和ETC:

使用🤗Transformers的BigBird

您可以像使用任何其他🤗模型一样使用BigBirdModel。让我们看一下下面的代码:

from transformers import BigBirdModel

# 从预训练的检查点中加载bigbird
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base")
# 这将使用默认配置初始化模型,即attention_type = “block_sparse” num_random_blocks = 3, block_size = 64。
# 但是您可以自由更改这些参数与任何检查点。这三个参数只会更改每个查询令牌要参与的令牌数量。
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", num_random_blocks=2, block_size=16)

# 通过将attention_type设置为`original_full`,BigBird将依赖于n^2复杂度的完全注意力。这样,BigBird与BERT相似度达到99.9%。
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", attention_type="original_full")

🤗Hub上有3个检查点可用(在撰写本文时):bigbird-roberta-basebigbird-roberta-largebigbird-base-trivia-itc。前两个检查点来自使用masked_lm loss预训练BigBirdForPretraining;而最后一个对应于在trivia-qa数据集上微调BigBirdForQuestionAnswering后的检查点。

让我们看一下您可以编写的最简代码(如果您想使用您的PyTorch训练器),使用🤗的BigBird模型来微调您的任务。

# 假设我们的任务是问答为例

from transformers import BigBirdForQuestionAnswering, BigBirdTokenizer
import torch

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# 使用预训练的权重初始化bigbird模型,并在其顶部随机初始化一个头部
model = BigBirdForQuestionAnswering.from_pretrained("google/bigbird-roberta-base", block_size=64, num_random_blocks=3)
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
model.to(device)

dataset = "torch.utils.data.DataLoader object"
optimizer = "torch.optim object"
epochs = ...

# 非常简单的训练循环
for e in range(epochs):
    for batch in dataset:
        model.train()
        batch = {k: batch[k].to(device) for k in batch}

        # 前向传播
        output = model(**batch)

        # 反向传播
        output["loss"].backward()
        optimizer.step()
        optimizer.zero_grad()

# 将最终权重保存在本地目录中
model.save_pretrained("<YOUR-WEIGHTS-DIR>")

# 将我们的权重推送到🤗Hub
from huggingface_hub import ModelHubMixin
ModelHubMixin.push_to_hub("<YOUR-WEIGHTS-DIR>", model_id="<YOUR-FINETUNED-ID>")

# 使用微调模型进行推理
question = ["你好吗?", "生活怎么样?"]
context = ["<一些包含答案-1的大背景>", "<一些包含答案-2的大背景>"]
batch = tokenizer(question, context, return_tensors="pt")
batch = {k: batch[k].to(device) for k in batch}

model = BigBirdForQuestionAnswering.from_pretrained("<YOUR-FINETUNED-ID>")
model.to(device)
with torch.no_grad():
    start_logits, end_logits = model(**batch).to_tuple()
    # 现在使用您想要的任何策略解码start_logits,end_logits。

# 注意:
# 这是非常简单的代码(如果您想使用原始的PyTorch),只是为了展示如何非常简单地使用BigBird
# 我建议使用🤗Trainer以获得许多功能

在使用Big Bird时,需要牢记以下几点:

  • 序列长度必须是块大小的倍数,即seqlen % block_size = 0。如果批次序列长度不是块大小的倍数,您无需担心,因为🤗Transformers会自动将其<pad>(到大于序列长度的最小块大小倍数)。
  • 目前,HuggingFace版本不支持ETC,因此只有第一个和最后一个块是全局的。
  • 当前的实现不支持num_random_blocks = 0
  • 作者建议,在序列长度<1024时,将attention_type = "original_full"
  • 必须满足以下条件:seq_length > global_token + random_tokens + sliding_tokens + buffer_tokens,其中global_tokens = 2 x block_sizesliding_tokens = 3 x block_sizerandom_tokens = num_random_blocks x block_sizebuffer_tokens = num_random_blocks x block_size。如果未能满足此条件,🤗Transformers将自动切换attention_typeoriginal_full并发出警告。
  • 当将big bird用作解码器(或使用BigBirdForCasualLM)时,attention_type应为original_full。但是您无需担心,如果您忘记这样设置,🤗Transformers将自动将attention_type切换为original_full

接下来是什么?

@patrickvonplaten在如何在trivia-qa数据集上评估BigBirdForQuestionAnswering上制作了一个非常酷的笔记本。随意使用那个笔记本玩转BigBird。

您很快将在库中找到类似BigBird Pegasus的模型,用于长文档摘要💥。

结束语

原始的块稀疏注意力矩阵的实现可以在此找到。您可以在此找到🤗的版本。