使用GradientTape进行TensorFlow模型训练

使用GradientTape进行TensorFlow模型训练的技巧与方法

Photo by Sivani Bandaru on Unsplash

使用 GradientTape 更新权重

TensorFlow 可能是目前最流行的深度学习库。我之前写了很多 TensorFlow 教程,现在仍在继续。TensorFlow 组织得非常好,使用起来也很简单,你不需要太担心模型的开发和训练。基本上大部分的事情都由库本身来处理,这也是它在工业界如此受欢迎的原因。但与此同时,有时候可以控制底层功能也是很好的。这给了你很多实验模型的权力。如果你是求职者,一些额外的知识可能会给你一个优势。

之前,我写了一篇文章介绍了如何开发自定义激活函数、层和损失函数。在本文中,我们将看到如何手动训练模型和更新权重。但别担心。你不需要重新记住微分法。TensorFlow 自身提供了 GradientTape() 方法来处理这部分。

如果 GradientTape() 对你来说是全新的,请随时查看此练习,了解 GradientTape() 的用法:Introduction to GradientTape in TensorFlow — Regenerative (regenerativetoday.com)

数据准备

在本文中,我们使用 TensorFlow 和 GradientTape() 来进行简单的分类算法。请从这个链接下载数据集:

Heart Failure Prediction Dataset (kaggle.com)

该数据集有一个开放的数据库许可证。

下面是必要的导入语句:

import tensorflow as tffrom tensorflow.keras.models import Modelfrom tensorflow.keras.layers import Dense, Inputimport numpy as npimport matplotlib.pyplot as pltimport matplotlib.ticker as mtickerimport pandas as pdfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import confusion_matriximport itertoolsfrom tqdm import tqdmimport tensorflow_datasets as tfds

创建数据集的 DataFrame:

import pandas as pddf = pd.read_csv('heart.csv')df