用Julia从头开始实现门控循环神经网络
用Julia实现门控循环神经网络
让我们探索Julia从零开始构建带有GRU单元的RNN
1. 简介
一段时间以前,我开始学习Julia进行科学编程和数据科学。Julia的持续受欢迎是由于将R的统计能力、Python的表达力和清晰的语法以及C++等编译语言的高性能相结合。
学习一样东西的最好方式就是不断地练习。这个“简单”的方法在技术领域显然是有效的。只有通过编码和实践,程序员才能掌握和探索语法、数据类型、函数、方法、变量、内存管理、控制流、错误处理以及包括最佳实践和约定在内的库。
基于这种信念,我开始了一个个人项目,构建了一个使用最先进的门控循环单元(GRU)架构的递归神经网络(RNN)。为了增加一些趣味性并增强我对Julia的理解,我从零开始构建了这个RNN。我的想法是使用带有GRU的RNN进行与股市相关的时间序列预测。
Julia中从零开始构建基于密度的聚类算法
让我们在数据科学中使用Julia作为Python的替代方案
pub.towardsai.net
- 一项新的人工智能研究推出了GPT4RoI:一种基于区域-文本对上进行指导调整的大型语言模型(LLM)的视觉-语言模型
- 2023年最佳UI和UX的人工智能工具
- 宾夕法尼亚大学的研究人员引入了一种替代的人工智能方法,用于设计和编程基于循环神经网络的储水池计算机
本文的概述如下:
- 理解GRU架构
- 建立项目
- 实现GRU网络
- 结果和洞察
- 结论
开始、派生、分享,最重要的是,使用为该项目创建的GitHub存储库进行实验👇。
GitHub – jodhernandezbe/post-gru-julia: 这是一个包含Julia代码的存储库,用于从零开始创建…
这是一个包含Julia代码的存储库,用于从零开始创建一个用于股票的门控循环神经网络…
github.com
2. 理解GRU架构
本节的目的不是对GRU架构进行详细描述,而是介绍编写从零开始使用GRU单元的RNN所需的元素。对于新手来说,我可以说RNN属于一类能够处理文本、股票价格和传感器数据等序列数据的模型家族。
揭示隐马尔可夫模型:概念、数学和实际应用
让我们探索隐藏的马尔可夫链
VoAGI.com
GRU的理念是克服普通RNN的梯度消失问题。Chi-Feng Wang撰写的文章可以简单解释这个问题👇。如果你想深入了解GRU,我鼓励你阅读以下易于理解和开源的论文:
- 关于神经机器翻译的特性:编码器-解码器方法
- 对序列建模中的门控循环神经网络进行实证评估
梯度消失问题
问题、原因、意义及解决方案
towardsdatascience.com
本文实现的RNN既不是深度的也不是双向的。Julia集成函数必须能够捕捉到这种行为。如图1所示,带有GRU单元的RNN由一系列连续的阶段组成。在每个阶段t,它提供了与前一阶段的隐藏状态(hₜ₋₁)相对应的元素。类似地,一个元素代表样本序列的第t个元素(即xₜ)。每个GRU单元的输出对应于该时间步的隐藏状态,将被传递给下一个阶段(即hₜ)。此外,hₜ可以通过类似Softmax的函数传递,以获得所需的输出(例如,文本中的一个词是否为形容词)。
图2描述了GRU单元的形成方式以及内部发生的信息流和数学运算。时间步t的单元包含一个更新门(zₜ),用于确定前面信息的哪个部分将传递给下一步,以及一个重置门(rₜ),用于确定前面信息的哪个部分应该被遗忘。使用rₜ、hₜ₋₁和xₜ,计算出当前步骤的候选隐藏状态(ĥₜ)。随后,使用zₜ、hₜ₋₁和ĥₜ计算出实际隐藏状态(hₜ₋₁)。所有这些操作组成了GRU单元中的前向传递,并在图3中的方程式中总结,其中Wᵣₕ、Wᵣₓ、Wₕₕ、Wₕₓ、W₂ₓ、W₂ₕ、bᵣ、bₕ和b₂是可学习的参数。 “*”表示矩阵乘法,而“・”表示逐元素乘法。
在文献中,通常可以找到如图4所示的前向传递方程式。在这个图中,使用矩阵连接来缩短图3中呈现的表达式。Wᵣ、Wₕ和W₂分别是Wᵣₕ和Wᵣₓ、Wₕₕ和Wₕₓ,以及W₂ₓ和W₂ₕ之间的垂直连接。方括号表示其中包含的元素是水平连接的。这两种表示方法都很有用,图4中的表示方法有助于缩短公式,而图3中的表示方法有助于理解反向传播方程式。
图4显示了必须在Julia程序中包含的反向传播方程式,用于模型训练。在这些方程式中,“T”表示矩阵的转置。我们可以通过使用多变量函数的全导数定义和链式法则来得到这些方程式。此外,您可以通过使用图形方法来指导自己👇:
GRUs中的前向传递和反向传播 – 推导 | 深度学习
解释了门控循环单元(GRUs)及其背后的数学,以及损失如何通过时间进行反向传播。
VoAGI.com
GRU单元
要使用GRU单元执行BPTT,我们有来自顶层的错误(\(\delta 1\)),未来隐藏…
cran.r-project.org
3. 设置项目
按照文档中的说明在计算机上安装Julia以运行项目:
官方二进制文件的特定平台说明
Julia语言的官方网站。Julia是一种快速、动态、易于使用和开源的语言…
julialang.org
与Python一样,您可以使用Julia内核的Jupyter笔记本。如果您希望这样做,请查看由Martin McGovern博士撰写的以下文章:
如何最好地使用Julia与Jupyter
如何将Julia代码添加到您的Jupyter笔记本中,并在同一份文件中同时使用Python和Julia…
towardsdatascience.com
3.1. 项目结构
GitHub存储库中的项目具有以下树形结构:
.├── data│ ├── AAPL.csv│ ├── GOOG.csv│ └── IBM.csv├── plots│ ├── residual_plot.png│ └── sequence_plot.png├── Project.toml├── .pre-commit-config.yaml├── src│ ├── data_preprocessing.jl│ ├── main.jl│ ├── prediction_plots.jl│ └── scratch_gru.jl└── tests (单元测试) ├── test_data_preprocessing.jl ├── test_main.jl └── test_scratch_gru.jl
文件夹:
data
:在此文件夹中,您将找到包含训练模型所需数据的.csv
文件。这里存储了股票价格文件。plots
:用于存储模型训练后获得的图形的文件夹。src
:这个文件夹是项目的核心,包含了预处理数据、训练模型、构建RNN架构、创建GRU单元和制作图形所需的.jl
文件。tests
:这个文件夹包含使用Julia构建的单元测试,用于确保代码的正确性和发现错误。对于这个文件夹内容的解释超出了本文的范围。您可以将其用作参考,并告诉我是否希望有一篇文章探讨Test
包。
单元测试
Base.runtests(tests=[“all”]; ncores=ceil(Int, Sys.CPU_THREADS / 2), exit_on_error=false, revise=false, [seed]) 运行…
docs.julialang.org
3.2. 必需包
尽管我们将从头开始,但以下包是必需的:
CSV
(0.10.11):CSV
是Julia中用于处理逗号分隔值(CSV)文件的包。DataFrames
(1.5.0):DataFrames
是Julia中处理表格数据的包。LinearAlgebra
(标准):LinearAlgebra
是Julia中提供了一系列线性代数例程的标准包。Base
(标准):Base
是Julia中提供基本功能和核心数据类型的标准模块。Statistics
(标准):Statistics
是Julia中提供用于数据分析的统计函数和算法的标准模块。ArgParse
(1.1.4):ArgParse
是Julia中用于解析命令行参数的包。它提供了一个简单灵活的方式来定义Julia脚本和应用程序的命令行界面。Plots
(1.38.16):Plots
是Julia中流行的绘图包,提供了一个高级接口用于创建数据可视化。Random
(标准):Random
是Julia中提供生成随机数和处理随机过程的函数的标准模块。Test
(标准,仅限单元测试):Test
是Julia中提供编写单元测试的实用工具的标准模块(超出本文的范围)。
通过使用Project.toml
,可以创建一个包含上述包的环境。这个文件类似于Python中的requirements.txt
或者Conda中的environment.yml
。运行以下命令来安装依赖项:
julia --project=. -e 'using Pkg; Pkg.instantiate()'
3.3. 股票价格
作为数据科学从业者,您了解数据是驱动每个机器学习或统计模型的燃料。在我们的例子中,领域特定的数据来自股票市场。Yahoo Finance提供公开可用的股票市场统计数据。我们将特别查看Google Inc.(GOOG)的历史统计数据。但是,您也可以搜索并下载其他公司的数据,例如IBM和苹果。
Alphabet Inc.(GOOG)股票历史价格和数据 – Yahoo Finance
在Yahoo Finance上发现GOOG股票的历史价格。查看每日、每周或每月格式,回溯到Alphabet…
finance.yahoo.com
4. 实现GRU网络
在src
文件夹中,您可以深入研究用于生成将在第5节中呈现的图表的文件(prediction_plots.jl
),在模型训练之前处理股票价格的文件(data_preprocessing.jl
),训练和构建GRU网络的文件(scratch_gru.jl
),以及一次性整合所有上述文件的文件(main.jl
)。在本节中,我们将深入研究组成GRU网络架构核心的四个函数,这些函数用于在训练期间实现前向传播和反向传播。
4.1. gru_cell_forward函数
下面呈现的代码片段对应于gru_cell_forward
函数。该函数接收当前输入(x
),上一个隐藏状态(prev_h
)和参数字典作为输入(parameters
)。通过上述参数,该函数实现了GRU单元前向传播的一个步骤,并使用sigmoid
和tanh
函数计算更新门(z
)、重置门(r
)、新的记忆单元或候选隐藏状态(h_tilde
)和下一个隐藏状态(next_h
),还计算了GRU单元的预测值(y_pred
)。该函数内部实现了图3和图4中的方程式。
4.2. gru_forward函数
与gru_cell_forward
不同,gru_forward
执行GRU网络的前向传播,即一系列时间步的前向传播。该函数接收输入张量(x
),初始隐藏状态(ho
)和参数字典作为输入(parameters
)。
如果您对序列模型还不熟悉,请不要将时间步与迭代次数混淆,以便最小化模型误差。
请不要将gru_cell_forward
接收的x
与gru_forward
接收的x
混淆。在gru_forward
中,x
有三个维度,而不是两个。第三个维度对应于RNN层拥有的总GRU单元数。简而言之,gru_cell_forward
与图2相关,而gru_forward
与图1相关。
gru_forward
在序列中的每个时间步上进行迭代,调用gru_cell_forward
函数计算next_h
和y_pred
。它分别将结果存储在h
和y
中。
4.3. gru_cell_backward函数
gru_cell_backward
函数执行单个GRU单元的反向传播。gru_cell_forward
接收隐藏状态的梯度(dh
)作为输入,同时包含用于计算图4中导数的元素的cache
(即next_h
,prev_h
,z
,r
,h_tilde
,x
和parameters
)。因此,gru_cell_backward
计算权重矩阵(即Wz
,Wr
和Wh
)和偏差(即bz
,br
和bh
)的梯度。所有梯度都存储在Julia字典(gradients
)中。
4.4. gru_backward函数
gru_backward
函数执行完整GRU网络的反向传播,即对完整时间步序列或GRU网络层中的GRU单元进行反向传播。此函数接收隐藏状态张量(dh
)和caches
的梯度作为输入。与gru_cell_backward
不同,gru_backward
的dh
具有第三个维度,对应于序列中的时间步数或GRU单元数。此函数按相反顺序迭代时间步骤,调用gru_cell_backward
计算每个时间步的梯度,并在循环中累积它们。
需要注意的是,在这个阶段重要的是,该项目只使用梯度下降来更新GRU网络参数,并不包括任何影响学习率或引入动量的功能。此外,该实现是针对回归问题创建的。然而,由于实施了模块化,只需要进行一些小的更改就可以获得不同的行为。
5. 结果和见解
现在让我们运行代码来训练GRU网络。由于集成了ArgParse
包,我们可以使用命令行参数来运行代码。如果您熟悉Python,该过程与之相同。我们将使用训练分割比例为0.7(split_ratio
),序列长度为10(seq_length
),隐藏大小为70(hidden_size
),1000个epochs(num_epochs
)和学习率为0.00001(learning_rate
)进行实验,因为本项目的目的不是优化超参数(这将涉及使用额外的模块)。运行以下命令开始训练:
julia --project src/main.jl --data_file GOOG.csv --split_ratio 0.7 --seq_length 10 --hidden_size 70 --num_epochs 1000 --learning_rate 0.00001
尽管模型经过1000个epochs的训练,但在
train_gru
函数中存在一个流程控制,用于存储参数的最佳值。
图5显示了训练迭代的成本。可以看到,曲线呈下降趋势,并且模型在最后几次迭代中似乎收敛。由于曲线的曲率,通过增加训练GRU网络的epochs数量,可能可以进一步改善。对测试集的外部评估得到的均方误差(MSE)约为6.57。尽管该值可能不接近零,但由于缺乏用于比较的基准值,无法得出最终结论。
图6展示了训练和测试数据集的实际值,其中散点表示,预测值则由连续线条表示(更多细节请参见图例)。很明显,该模型与实际点趋势相匹配;然而,需要更多的训练来改善GRU网络的性能。即便如此,在图片的某些部分,特别是在训练方面,模型过度拟合了一些样本;由于训练集的均方误差约为1.70,可能模型出现了一些过拟合。
误差的波动或不稳定可能是回归分析中的一个困难,包括时间序列预测。在数据科学和统计学中,这被称为异方差性(有关更多信息,请查看下面的文章👇)。残差图是检测异方差性的一种方法。图7说明了残差图,其中x轴表示预测值,y轴表示残差。
回归学习中的异方差性和同方差性
回归分析中残差的可变性
pub.towardsai.net
围绕零值均匀分布的点表示同方差性(即残差稳定)。该图显示了该情景中异方差性的证据,这将需要使用方法来纠正问题(例如对数变换),以创建高性能模型。图7显示,无论样本来自训练集还是测试集,都可以清楚地看到异方差性存在于120美元以上的区域。图6有助于加强这一观点。图6表明,大于120的值与实际数字相差较大。
结论
在本文中,我们使用Julia编程语言从零开始构建了一个GRU网络。我们首先研究了程序需要考虑的数学方程,以及GRU实现成功所需的最关键的理论问题。我们讨论了如何创建初步设置,以处理数据,建立GRU架构,训练模型和评估模型。我们回顾了在模型架构设计中使用的最重要的代码片段。我们运行程序以分析结果。我们检测到这个特定项目中存在异方差性,并建议研究解决这个问题并创建高性能GRU网络的策略。
我邀请您在GitHub上查看代码,并让我知道是否希望我们查看Julia或其他数据科学或编程主题的其他内容。您的想法和反馈对我非常有帮助🚀…
附加材料
- 门控循环单元(GRU)
- 序列模型完全课程
如果您喜欢我的文章,请在VoAGI上关注我,以获取更多引人思考的内容,并与您的同事分享这些材料。