将深度学习论文中的数学实现转化为高效的PyTorch代码:SimCLR对比损失

Translate Convert mathematical implementations from deep learning papers into efficient PyTorch code SimCLR contrastive loss.

学习将高级数学公式实现为高性能的PyTorch代码。

Jeswin Thomas在Unsplash上的照片

介绍

加深对深度学习模型和损失函数背后数学原理的理解,以及提高PyTorch技能的最佳方法之一是熟悉自己实现深度学习论文。

书籍和博客文章可以帮助你开始编码和学习机器学习/深度学习的基础知识,但在学习了其中几篇并且在该领域的日常任务中变得熟练之后,你很快就会意识到自己在学习过程中是独自一人,并且会发现大多数在线资源都很无聊且过于肤浅。然而,我相信,如果你能及时研究并理解新发布的深度学习论文中所需的数学部分(不一定是作者理论背后的全部数学证明),并且你是一个能够将它们实现为高效代码的合格编码者,没有什么能阻止你在该领域保持最新并学习新思想。

对比损失(Contrastive Loss)实现

我将介绍我在实现深度学习论文中的数学部分时所遵循的常规步骤,以一个不那么简单的例子为例:SimCLR论文中的对比损失

这是损失的数学形式:

来自SimCLR论文的对比(NT-Xent)损失 | 来自https://arxiv.org/pdf/2002.05709.pdf

我同意这个公式的外观可能会让人望而生畏!你可能会认为在GitHub上一定有很多现成的PyTorch实现,所以让我们使用它们吧:) 是的,你是对的。在线上有数十种实现。然而,我认为这是一个练习这一技能的好例子,并且可以作为一个良好的起点。

将数学实现为代码的步骤

我在将论文中的数学部分实现为高效的PyTorch代码时的常规步骤如下:

  1. 理解数学,并用简单的术语解释它
  2. 使用简单的Python“for”循环实现初始版本,暂时不使用复杂的矩阵乘法
  3. 将你的代码转换为高效且适用于矩阵的PyTorch代码

好的,让我们直接进入第一步。

第一步:理解数学并用简单的术语解释它

我假设你对线性代数有基础知识,并熟悉数学符号。如果你不了解,你可以使用这个工具,通过绘制符号来了解每个符号代表的含义及其在数学中的作用。你也可以查看这个很棒的维基百科页面,其中描述了大多数符号。在这些机会中,你会学到新东西,通过搜索和阅读需要的内容来学习。我认为这是一种更高效的学习方式,而不是从头开始阅读数学教科书,并在几天后放弃它。

回到我们的事情上来。如上述公式前面的段落添加了更多的背景信息,在SimCLR学习策略中,你从N个图像开始,将每个图像转换2次以获得这些图像的增强视图(现在有2*N个图像)。然后,你将这2 * N个图像通过模型,获取每个图像的嵌入向量。现在,你希望使同一图像的2个增强视图(一个正样本对)在嵌入空间中更接近(并对所有其他正样本对执行相同操作)。衡量两个向量相似(接近,同一方向)的一种方法是使用余弦相似度,其定义为sim(u, v)(在上面的图像中查找定义)。

简单来说,这个公式描述的是对于我们批次中的每个项目,即图像的一个增强视图的嵌入(记住:批次包含不同图像的增强视图的所有嵌入→如果从N个图像开始,则批次的大小为2*N),我们首先找到该图像的另一个增强视图的嵌入以构成正对。然后,我们计算这两个嵌入的余弦相似度并对其进行指数化(公式中的分子)。然后,我们计算我们从开始的第一个嵌入向量构建的所有其他对的余弦相似度的指数化(除了与自身的对,这就是公式中的1[k!=i]的含义),并将它们相加以构建分母。现在,我们可以将分子除以分母并取其自然对数并翻转符号!现在,我们有了批次中第一个项目的损失。我们只需要对批次中的所有其他项目重复相同的过程,然后取平均值,以便能够调用PyTorch的.backward()方法来计算梯度。

第二步:使用简单的Python代码实现,使用简单的“for”循环!

简单的Python实现,使用慢速的“for”循环

让我们来看看代码。假设我们有两个图像:A和B。变量aug_views_1保存这两个图像(A1和B1)的一个增强视图的嵌入(每个嵌入大小为3),aug_views_2也是一样(A2和B2);因此,两个矩阵的第一个项目与图像A相关,第二个项目与图像B相关。我们将这两个矩阵连接到投影矩阵中(其中包含4个向量:A1、B1、A2、B2)。

为了保持投影矩阵中向量之间的关系,我们定义了pos_pairs字典来存储在连接矩阵中相关的两个项目。(很快我会解释F.normalize()的作用!)

如您在代码的下面几行中所见,我在一个for循环中遍历投影矩阵中的项目,找到使用我们的字典相关的向量,然后计算余弦相似度。您可能想知道为什么在计算余弦相似度的那一行中不除以向量的大小,就像余弦相似度公式建议的那样。关键是,在开始循环之前,我使用F.normalize函数来将投影矩阵中的所有向量归一化为大小为1。因此,在计算余弦相似度的那一行中,不需要除以大小。

在构建分子之后,我找到批次中向量的所有其他索引(除了相同的索引i),以计算构成分母的余弦相似度。最后,我通过将分子除以分母并应用对数函数并翻转符号来计算损失。确保通过调整代码来了解每行发生的情况。

第三步:将其转换为高效的矩阵友好型PyTorch代码

前面Python实现的问题是速度太慢,无法在我们的训练流程中使用;我们需要摆脱慢速的“for”循环,并将其转换为矩阵乘法和数组操作,以利用并行化能力。

PyTorch实现

让我们看看这段代码中发生了什么。这次,我引入了labels_1和labels_2张量来编码这些图像属于的任意类别,因为我们需要一种方式来编码A1、A2和B1、B2图像之间的关系。选择标签0和1(如我所做)或者说5和8并不重要。

在连接嵌入和标签后,我们首先创建一个sim_matrix,其中包含所有可能的配对的余弦相似度。

How the sim_matrix looks like: the green cells contain our positive pairs, the orange cells are the pairs which need to be ignored in the denominator | Visualization by the author

上面的可视化图表是你理解代码工作原理以及我们为什么要按照特定步骤进行的全部所需。考虑到sim_matrix的第一行,我们可以计算批次中第一个项目(A1)的损失,计算方法如下:我们需要将A1A2(指数化)除以A1B1、A1A2和A1B2(每个项先指数化)的总和,并将结果保存在存储所有损失的张量的第一个项目中。因此,我们首先需要创建一个掩码来找到上面可视化图表中的绿色单元格。代码中定义变量mask的两行正是这样做的。分子通过将我们的sim_matrix与刚刚创建的掩码相乘,然后对每一行的项进行求和来计算。在掩码处理后,每一行将只有一个非零项,即绿色单元格。为了计算分母,我们需要对每一行求和,忽略对角线上的橙色单元格。为此,我们将使用PyTorch张量的.diag()方法。其余部分都是不言自明的!

奖励:使用AI助手(ChatGPT,Copilot等)来实现公式

我们有很多强大的工具可以帮助我们理解和实现深度学习论文中的数学。例如,你可以在给出论文中的公式后,询问ChatGPT(或其他类似的工具)来实现PyTorch代码。在我的经验中,如果你能自己以某种方式进入pythonic-for-loop实现步骤,ChatGPT可以提供最有帮助并且在较少尝试和错误的情况下提供最佳的最终答案。将这种朴素的实现给ChatGPT,并要求将其转换为仅使用矩阵乘法和张量操作的高效PyTorch代码,你会感到惊讶的:)

进一步阅读

我鼓励你查看以下两个关于同一思想的精彩实现,了解如何扩展此实现以考虑更复杂的情况,例如在监督对比学习设置中。

  1. Guillaume Erhard的监督对比损失
  2. Yonglong Tian的SupContrast

关于我

我是Moein Shariatnia,一名机器学习开发人员兼医学生,专注于将深度学习解决方案应用于医学成像应用。我的研究主要是探索在各种情况下深度模型的泛化能力。随时通过电子邮件、Twitter或LinkedIn与我联系。