Nyström形式:通过Nyström方法在线性时间和内存中近似自注意力

Nyström形式:线性时间和内存中近似自注意力的Nyström方法

简介

Transformer 在自然语言处理和计算机视觉任务中展现了出色的性能。其成功归功于自注意力机制,该机制捕捉了输入中所有标记之间的成对交互。然而,标准自注意力机制的时间和内存复杂度为 O(n^2),这使得在长输入序列上训练变得昂贵。

尼斯特罗姆变换器(Nyströmformer)是众多高效 Transformer 模型之一,它使用 O(n) 的复杂度近似标准自注意力。尼斯特罗姆变换器在各种下游自然语言处理和计算机视觉任务中展现了竞争力,并改进了标准自注意力的效率。本博文旨在为读者概述尼斯特罗姆方法以及如何将其调整为近似自注意力。

矩阵近似的尼斯特罗姆方法

尼斯特罗姆变换器的核心是矩阵近似的尼斯特罗姆方法。它允许我们通过对矩阵的部分行和列进行采样来近似矩阵。让我们考虑一个昂贵的 n × n 矩阵 P,我们无法完全计算它。因此,我们使用尼斯特罗姆方法来近似它。我们首先从 P 中采样 m 行和列。然后,我们可以将采样的行和列排列如下:

将 P 表示为块状矩阵

现在我们有了四个子矩阵:A P、B P、F P、C P,它们的大小分别为 m × m、m × (n – m)、(n – m) × m 和 (n – m) × (n – m)。采样的 m 列包含在 A P 和 F P 中,而采样的 m 行包含在 A P 和 B P 中。因此,A P、B P 和 F P 的条目是已知的,我们将估计 C P。根据尼斯特罗姆方法,C P 的计算公式如下:

C P = F P A P + B P

这里,+ 表示摩尔 – 彭罗斯逆矩阵。因此,矩阵 P 的尼斯特罗姆近似 P ^ 可以表示为:

P 的尼斯特罗姆近似

如第二行所示,P ^ 可以表示为三个矩阵的乘积。为什么这样做将在后面变得清晰。

我们能用尼斯特罗姆方法来近似自注意力吗?

我们的目标是最终近似标准自注意力中的 softmax 矩阵:S = softmax(QK^T / √d)

这里,Q 和 K 分别表示查询和键。按照上述步骤,我们将从 S 中采样 m 行和列,形成四个子矩阵,并获得 S ^:

S 的尼斯特罗姆近似

但是,从 S 中采样一列意味着我们从每行中选择一个元素。回想一下 S 的计算过程:最后的操作是逐行进行 softmax。要找到一行中的单个条目,我们必须访问所有其他条目(用于 softmax 中的分母)。因此,采样一列需要我们知道矩阵中的所有其他列。因此,我们无法直接应用尼斯特罗姆方法来近似 softmax 矩阵。

如何将Nystrom方法应用于近似自注意力?

作者提出,不再从S S S中进行采样,而是从查询和键中采样地标点(或Nystrom点)。我们分别将查询地标点和键地标点表示为Q ~和K ~。Q ~和K ~可以用于构建与S S S的Nystrom近似中相应的三个矩阵。我们定义以下矩阵:

F ~ = softmax(QK ~ T / sqrt(d))

A ~ = softmax(Q ~ K ~ T / sqrt(d))

B ~ = softmax(Q ~ KT / sqrt(d))

F ~(n×m)、A ~(m×m)和B ~(m×n)分别是矩阵的大小。我们用新定义的三个矩阵替换S S S的Nystrom近似中的三个矩阵,从而得到一种替代的Nystrom近似:

S ^ = F ~ A ~ B ~ = softmax(QK ~ T / sqrt(d)) softmax(Q ~ K ~ T / sqrt(d)) + softmax(Q ~ KT / sqrt(d))

这是自注意机制中softmax矩阵的Nystrom近似。我们将该矩阵与值矩阵(V V V)相乘,以获得自注意力的线性近似。请注意,我们从未计算过QK T的乘积,避免了O(n^2)的复杂度。

我们如何选择地标点?

作者建议不再从Q Q Q和K K K中随机采样m m m行,而是使用分段均值构建Q ~和K ~。在这个过程中,n n n个标记被分为m m m个段,然后计算每个段的均值。理想情况下,m m m远小于n n n。根据论文中的实验,即使对于较长的序列长度(n = 4096或8192),选择只有32或64个地标点的效果与标准自注意力和其他高效的注意力机制相当。

论文总结了整个算法,如下图所示:

使用Nyström方法的高效自注意力

上面的三个橙色矩阵对应于使用关键点和查询关键点构建的三个矩阵。同时,注意到有一个DConv框。这对应于使用一维深度卷积将值添加到跳过连接中。

Nyströmformer是如何实现的?

Nyströmformer的原始实现可以在这里找到,HuggingFace的实现可以在这里找到。让我们看一下HuggingFace实现中的一些代码行(添加了一些注释)。为简单起见,一些细节,如归一化、注意力掩码和深度卷积,被省略。

key_layer = self.transpose_for_scores(self.key(hidden_states)) # K
value_layer = self.transpose_for_scores(self.value(hidden_states)) # V
query_layer = self.transpose_for_scores(mixed_query_layer) # Q

q_landmarks = query_layer.reshape(
    -1,
    self.num_attention_heads,
    self.num_landmarks,
    self.seq_len // self.num_landmarks,
    self.attention_head_size,
).mean(dim=-2) # \tilde{Q}

k_landmarks = key_layer.reshape(
    -1,
    self.num_attention_heads,
    self.num_landmarks,
    self.seq_len // self.num_landmarks,
    self.attention_head_size,
).mean(dim=-2) # \tilde{K}

kernel_1 = torch.nn.functional.softmax(torch.matmul(query_layer, k_landmarks.transpose(-1, -2)), dim=-1) # \tilde{F}
kernel_2 = torch.nn.functional.softmax(torch.matmul(q_landmarks, k_landmarks.transpose(-1, -2)), dim=-1) # \tilde{A} before pseudo-inverse

attention_scores = torch.matmul(q_landmarks, key_layer.transpose(-1, -2)) # \tilde{B} before softmax

kernel_3 = nn.functional.softmax(attention_scores, dim=-1) # \tilde{B}
attention_probs = torch.matmul(kernel_1, self.iterative_inv(kernel_2)) # \tilde{F} * \tilde{A}
new_value_layer = torch.matmul(kernel_3, value_layer) # \tilde{B} * V
context_layer = torch.matmul(attention_probs, new_value_layer) # \tilde{F} * \tilde{A} * \tilde{B} * V

如何使用HuggingFace的Nyströmformer

可以在HuggingFace上使用Nyströmformer进行遮蔽语言建模(MLM)。目前有4个检查点,对应于不同的序列长度:nystromformer-512nystromformer-1024nystromformer-2048nystromformer-4096。可以使用NystromformerConfig中的num_landmarks参数来控制地标点的数量m。让我们来看一个使用Nyströmformer进行MLM的最简示例:

from transformers import AutoTokenizer, NystromformerForMaskedLM
import torch

tokenizer = AutoTokenizer.from_pretrained("uw-madison/nystromformer-512")
model = NystromformerForMaskedLM.from_pretrained("uw-madison/nystromformer-512")

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# retrieve index of [MASK]
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]

predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
tokenizer.decode(predicted_token_id)

另外,我们还可以使用pipeline API(它会处理所有复杂性):

from transformers import pipeline
unmasker = pipeline('fill-mask', model='uw-madison/nystromformer-512')
unmasker("Paris is the [MASK] of France.")

结论

Nyströmformer为标准自注意力机制提供了高效的近似方法,同时在性能上超越了其他线性自注意力方案。在本博文中,我们简要介绍了Nyström方法的概述以及如何利用它进行自注意力。对于有兴趣部署或微调Nyströmformer用于下游任务的读者,可以在这里找到HuggingFace的文档。