用10亿训练对训练一个句子嵌入模型
训练句子嵌入模型,使用10亿数据
句子嵌入是一种将句子映射到实数向量的方法。理想情况下,这些向量应该能够捕捉句子的语义并且具有高度的通用性。这样的表示可以用于许多下游应用,如聚类、文本挖掘或问题回答。
我们作为“使用1B个训练对训练最佳句子嵌入模型”的项目的一部分开发了最先进的句子嵌入模型。该项目在由Hugging Face组织的使用JAX/Flax进行NLP和CV的社区周活动中进行。我们从高效的硬件基础设施中受益,以支持项目的运行:7个TPU v3-8,以及来自Google的Flax、JAX和Cloud团队成员关于高效深度学习框架的指导!
训练方法
模型
与单词不同,我们无法定义一个有限的句子集合。因此,句子嵌入方法通过组合单词来计算最终表示。例如,SentenceBert模型(Reimers和Gurevych,2019)使用了Transformer,这是许多NLP应用的基石,然后在上下文化单词向量上进行池化操作。(参见下图。)
多负排名损失
通常使用自监督目标来学习组合模块的参数。对于该项目,我们使用了下图所示的对比训练方法。我们构建了一个包含句子对(ai,pi)的数据集,其中来自该对的句子具有相似的含义。例如,我们考虑了诸如(查询,答案-段落)、(问题,重复问题)、(论文标题,引用论文标题)等对。然后,我们训练模型将句子对(ai,pi)映射到接近的向量,同时将不匹配的句子对(ai,pj),i≠j映射到嵌入空间中的远离向量。这种训练方法也被称为批内负样本训练、InfoNCE或NTXentLoss。
形式上,给定一批训练样本,模型优化以下损失函数:
− 1 n ∑ i = 1 n e x p ( s i m ( a i , p i ) ) ∑ j e x p ( s i m ( a i , p j ) )
下图是一个说明性的示例。模型首先将批中每个句子从每个句子对中进行嵌入。然后,我们计算每个可能的句子对(ai,pj)之间的相似性矩阵。然后,我们将相似性矩阵与原始句子对的真值进行比较。最后,我们使用交叉熵损失进行比较。
直观地,模型应该将句子“柏林有多少人口?”和“柏林有大约350万人口”之间的相似性评分设为较高,将其他负面答案如“法国的首都是巴黎”之间的相似性评分设为较低,如下图所示。
在损失方程中,sim
表示(a,p)之间的相似性函数。相似性函数可以是余弦相似性或点积运算符。这两种方法都有其优缺点,如下所述(Thakur等人,2021;Bachrach等人,2014):
在实践中,我们使用了一个经过缩放的相似性,因为得分差异往往太小,并应用了一个缩放因子C,使得sim\_scaled(a, b) = C * sim(a, b),通常C = 20(Henderson等人,2020;Radford等人,2021)。
通过更好的批次提高质量
在我们的方法中,我们建立样本对 ( a i , p i ) (a_i , p_i) ( a i , p i ) 的批次。我们将批次中的所有其他样本 ( a i , p j ) , i ≠ j (a_i , p_j), i \neq j ( a i , p j ) , i = j ,视为负样本对。因此,批次的组成是关键的训练方面。鉴于该领域的文献,我们主要关注了批次的三个主要方面。
1. 大小重要
在对比学习中,较大的批次大小与更好的性能是同义的。如 Qu 等人(2021)的图所示,较大的批次大小会增加结果。
2. 难负样本
在同一图中,我们观察到包含难负样本也会提高性能。难负样本是很难与 p i p_i p i 区分的样本 p j p_j p j 。在我们的例子中,它可能是对「法国的首都是什么?」和「美国的首都是什么?」这两个具有相似语义内容且需要准确理解完整句子才能正确回答的句对。相反,对于「法国的首都是什么?」和「有多少部星球大战电影?」这两个样本,它们较容易区分,因为它们不涉及相同的主题。
3. 跨数据集批次
我们连接了多个数据集来训练我们的模型。我们构建了一个大批次,并从同一批次数据集中收集样本,以限制主题分布并倾向于难负样本。但是,我们还在批次中混合了至少两个数据集,以学习主题之间的全局结构,而不仅仅是主题内的局部结构。
培训基础设施和数据
如前所述,数据量和批次大小直接影响模型的性能。作为项目的一部分,我们受益于高效的硬件基础设施。我们在由 Google 开发的 TPUs 上训练了我们的模型,这些 TPUs 对于矩阵乘法非常高效。TPUs具有一些硬件特性,可能需要一些特定的代码实现。
此外,我们在一个大语料库上训练了模型,最多连接了10亿个句子对!所有使用的数据集在模型卡片中都有详细说明。
结论
在挑战期间,您可以在我们的HuggingFace仓库中找到我们创建的所有模型和数据集。我们训练了20个通用的句子转换模型,如 Mini-LM(Wang等人,2020)、RoBERTa(liu等人,2019)、DistilBERT(Sanh等人,2020)和MPNet(Song等人,2020)。我们的模型在多个通用句子相似性评估任务上取得了SOTA的成绩。我们还分享了8个专门用于问答、句子相似性和性别评估的数据集。
通用句子嵌入可以用于许多应用。我们建立了一个 Spaces 演示来展示几个应用:
- 「句子相似性」模块比较主要文本与您选择的其他文本之间的相似性。在背景中,演示提取每个文本的嵌入,并使用余弦相似度计算源句子与其他句子之间的相似性。
- 「非对称问答」将给定查询的答案可能性与您选择的答案候选进行比较。
- 「搜索/聚类」从查询中返回附近的答案。例如,如果您输入「Python」,它将使用点积距离检索最接近的句子。
- 「性别偏见评估」通过对句子进行随机抽样,报告训练集中固有的性别偏见。给定一个没有提及目标职业性别的锚定文本,以及具有性别代词的两个命题,我们比较模型是否将更高的相似度分配给给定的命题,从而评估它们更倾向于支持特定性别的比例。
使用 JAX/Flax 进行 NLP 和 CV 的社区周是一次紧张而非常有收获的经历!Google 的 Flax、JAX 和 Cloud 以及 Hugging Face 团队成员的指导和支持帮助我们所有人都学到了很多。我们希望所有项目都能像我们一样开心。如果您有任何问题或建议,请随时与我们联系!