使用JAX加速我们的研究

'使用JAX加速研究'

DeepMind工程师通过构建工具、扩展算法以及创建具有挑战性的虚拟和物理世界来加速我们的研究,用于训练和测试人工智能(AI)系统。作为这项工作的一部分,我们不断评估新的机器学习库和框架。

最近,我们发现越来越多的项目受益于JAX,这是Google研究团队开发的机器学习框架。JAX与我们的工程哲学相契合,并且在过去一年中已经被我们的研究社区广泛采用。在这里,我们分享了与JAX合作的经验,概述了为支持全球研究人员构建的生态系统,并说明了为什么我们认为它对我们的AI研究有用。

为什么选择JAX?

JAX是一个专为高性能数值计算而设计的Python库,特别适用于机器学习研究。它的数值函数API基于NumPy,一个用于科学计算的函数集合。Python和NumPy都被广泛使用且熟悉,使得JAX简单、灵活且易于采用。

除了NumPy API,JAX还包括一个可扩展的函数转换系统,用于支持机器学习研究,其中包括:

  • 微分:基于梯度的优化是机器学习的基础。JAX原生支持任意数值函数的前向和反向模式自动微分,通过函数转换(如grad、hessian、jacfwd和jacrev)实现。
  • 向量化:在机器学习研究中,我们经常对大量数据应用单个函数,例如对批次计算损失或计算不同ially private learning的每个示例的梯度。JAX提供了通过vmap转换实现的自动向量化,简化了这种形式的编程。例如,研究人员在实现新算法时不需要考虑批处理。JAX还通过相关的pmap转换支持大规模数据并行处理,优雅地分布数据,以适应单个加速器内存无法容纳的大数据。
  • JIT编译:使用XLA可以即时(JIT)编译和执行JAX程序,适用于GPU和Cloud TPU加速器。JIT编译与JAX的NumPy一致的API结合,使得没有高性能计算经验的研究人员可以轻松扩展到一个或多个加速器。

我们发现JAX使得快速尝试新算法和架构成为可能,并成为我们最近发表的许多论文的基础。欲了解更多信息,请考虑参加我们在NeurIPS虚拟会议上举办的JAX圆桌会议,时间为格林威治标准时间12月9日晚上7:00。

DeepMind的JAX应用

支持最先进的AI研究意味着在快速原型设计和快速迭代之间取得平衡,同时能够以传统生产系统所特有的规模部署实验。这些项目的挑战在于研究领域的快速演变和难以预测性。在任何时刻,一项新的研究突破可能会改变整个团队的发展轨迹和需求。在这个不断变化的环境中,我们工程团队的核心责任是确保从一个研究项目中学到的经验和编写的代码能够有效地在下一个项目中复用。

已证明成功的一种方法是模块化:我们将每个研究项目中开发的最重要和关键的构建模块提取出来,形成经过充分测试和高效的组件。这使得研究人员可以专注于他们的研究,同时也从我们的核心库中实现的算法组件的代码复用、错误修复和性能改进中受益。我们还发现,确保每个库具有明确定义的范围,并确保它们互操作但独立是非常重要的。增量式的接受,即能够选择功能而不受限制,对于为研究人员提供最大的灵活性并始终支持他们选择适合工作的正确工具至关重要。

我们在开发JAX生态系统时考虑的其他因素包括确保它在可能的情况下与我们现有的TensorFlow库(如Sonnet和TRFL)的设计保持一致。我们还致力于构建与底层数学尽可能相匹配的组件,以便自描述并尽量减少从论文到代码的跳跃。最后,我们选择开源我们的库,以促进研究成果的共享,并鼓励更广泛的社区探索JAX生态系统。

我们的生态系统今天

Haiku ‍

可组合函数转换的JAX编程模型可以使处理有状态对象变得复杂,例如具有可训练参数的神经网络。Haiku是一个神经网络库,它允许用户使用熟悉的面向对象的编程模型,同时利用JAX纯函数编程范式的强大和简洁。

Haiku目前在DeepMind和Google的数百名研究人员中得到积极使用,并已在几个外部项目中得到采用(例如Coax,DeepChem,NumPyro)。它构建在Sonnet的API之上,Sonnet是我们在TensorFlow中用于神经网络的基于模块的编程模型,我们的目标是尽可能简化从Sonnet迁移到Haiku的过程。

在GitHub上了解更多

Optax

基于梯度的优化是机器学习的基础。Optax提供了一系列梯度变换库,以及组合运算符(例如chain),可以在一行代码中实现许多标准优化器(例如RMSProp或Adam)。

Optax的组合性质自然支持在自定义优化器中重组相同的基本组件。它还提供了一些用于随机梯度估计和二阶优化的实用工具。

许多Optax用户已经采用了Haiku,但根据我们的渐进式采用理念,任何将参数表示为JAX树结构的库都得到支持(例如Elegy,Flax和Stax)。请在这里了解更多关于JAX库丰富生态系统的信息。

在GitHub上了解更多

RLax

我们最成功的项目之一是深度学习和强化学习(RL)的交叉领域,也被称为深度强化学习。RLax是一个为构建RL代理提供有用构建块的库。

RLax中的组件涵盖了广泛的算法和思想:TD学习,策略梯度,演员评论家,MAP,近端策略优化,非线性值转换,通用值函数以及一些探索方法。

虽然提供了一些入门示例代理,但RLax并不打算成为构建和部署完整RL代理系统的框架。一个使用RLax组件构建的全功能代理框架的示例是Acme。

在GitHub上了解更多

Chex

测试对于软件的可靠性至关重要,研究代码也不例外。从研究实验中得出科学结论需要对代码的正确性有信心。Chex是一个用于验证常见构建块的正确性和稳健性的测试工具集合,供库作者使用,以及供终端用户检查他们的实验代码。

Chex提供了各种实用工具,包括JAX感知的单元测试、JAX数据类型属性的断言、模拟和伪造以及多设备测试环境。Chex在DeepMind的JAX生态系统以及Coax和MineRL等外部项目中使用。

在GitHub上了解更多

Jraph

图神经网络(GNNs)是一个令人兴奋的研究领域,具有许多有前途的应用。例如,参见我们在Google地图上的交通预测和我们在物理模拟方面的工作。Jraph(发音为”giraffe”)是一个轻量级的库,用于支持在JAX中使用GNNs。

Jraph提供了一种标准化的图数据结构,一组用于处理图的实用工具,以及一系列易于分叉和扩展的图神经网络模型的“动物园”。其他关键特性包括:批处理图元组,可以高效利用硬件加速器;通过填充和屏蔽支持可变形状图的即时编译;以及在输入分区上定义的损失。与Optax和我们的其他库一样,Jraph不对用户选择的神经网络库施加任何限制。

了解如何使用该库的更多信息,请参阅我们丰富的示例集合。

在GitHub上了解更多

我们的JAX生态系统不断发展,我们鼓励机器学习研究社区探索我们的库以及JAX加速他们自己的研究的潜力。