使用🤗 Transformers对低资源ASR进行微调XLSR-Wav2Vec2
'使用🤗 Transformers微调XLSR-Wav2Vec2低资源ASR'
新(11/2021):此博客帖子已更新,以介绍XLSR的后继者,称为XLS-R。
Wav2Vec2是一个用于自动语音识别(ASR)的预训练模型,由Alexei Baevski、Michael Auli和Alex Conneau于2020年9月发布。在Wav2Vec2在最受欢迎的英语ASR数据集之一LibriSpeech上展示出卓越性能后,Facebook AI推出了Wav2Vec2的多语言版本,称为XLSR。XLSR代表跨语言语音表示,指的是模型学习跨多种语言有用的语音表示的能力。
XLSR的后继者,简称为XLS-R(指“语音的XLM-R”),由Arun Babu、Changhan Wang、Andros Tjandra等人于2021年11月发布。XLS-R在自监督预训练中使用了近500,000小时的128种语言的音频数据,并提供了从3亿到20亿个参数的不同规模的模型。您可以在🤗 Hub上找到预训练的检查点:
- Wav2Vec2-XLS-R-300M
- Wav2Vec2-XLS-R-1B
- Wav2Vec2-XLS-R-2B
与BERT的掩码语言建模目标类似,XLS-R通过在自监督预训练期间随机掩盖特征向量,然后将其传递给变换器网络来学习上下文化的语音表示(即左侧的图表)。
在微调过程中,我们在预训练网络的顶部添加了一个线性层,以便在带标签的音频下游任务(如语音识别、语音翻译和音频分类)上训练模型(即右侧的图表)。
XLS-R在语音识别、语音翻译和说话人/语言识别方面显示出令人印象深刻的改进,请参考官方论文的表3-6、表7-10和表11-12。
安装
在本博客中,我们将详细解释如何微调XLS-R,特别是预训练检查点Wav2Vec2-XLS-R-300M,用于ASR。
为了演示目的,我们将在仅包含约4小时验证训练数据的低资源ASR数据集Common Voice上微调模型。
XLS-R使用连接主义时序分类(CTC)进行微调,CTC是一种用于训练序列到序列问题的神经网络的算法,例如ASR和手写识别。
我强烈推荐阅读Awni Hannun的精心撰写的博客文章《使用CTC进行序列建模(2017年)》。
在开始之前,让我们安装datasets
和transformers
。此外,我们还需要torchaudio
来加载音频文件,以及jiwer
来使用单词错误率(WER)指标评估我们微调的模型 1 {}^1 1 。
!pip install datasets==1.18.3
!pip install transformers==4.11.3
!pip install huggingface_hub==0.1
!pip install torchaudio
!pip install librosa
!pip install jiwer
我们强烈建议在训练过程中直接将训练检查点上传到Hugging Face Hub。Hugging Face Hub具有集成的版本控制,因此您可以确保在训练过程中不会丢失任何模型检查点。
为此,您需要存储来自Hugging Face网站的身份验证令牌(如果尚未注册,请在此处注册!)
from huggingface_hub import notebook_login
notebook_login()
输出:
登录成功
您的令牌已保存到/root/.huggingface/token
然后您需要安装Git-LFS来上传您的模型检查点:
apt install git-lfs
1 {}^1 1 在论文中,模型使用音素错误率(PER)进行评估,但在ASR中最常见的度量标准是词错误率(WER)。为了尽可能保持笔记本的通用性,我们决定使用WER评估模型。
准备数据、分词器、特征提取器
ASR模型将语音转录为文本,这意味着我们既需要一个处理语音信号并将其转换为模型输入格式(例如特征向量)的特征提取器,也需要一个处理模型输出格式并将其转换为文本的分词器。
在🤗 Transformers中,XLS-R模型配备了一个名为Wav2Vec2CTCTokenizer的分词器和一个名为Wav2Vec2FeatureExtractor的特征提取器。
让我们首先创建一个分词器,将预测的输出类解码为输出转录。
创建Wav2Vec2CTCTokenizer
预训练的XLS-R模型将语音信号映射到一系列上下文表示,如上图所示。然而,对于语音识别,模型必须将这个上下文表示序列映射到对应的转录,这意味着必须在Transformer块的顶部添加一个线性层(在上图中以黄色显示)。这个线性层用于将每个上下文表示分类为一个标记类,类似于在预训练后在BERT的嵌入层之上添加线性层以进行进一步的分类(参见以下博文中的“BERT”部分)。在预训练后,将在BERT的嵌入层之上添加线性层以进行进一步的分类-参见本博文中的“BERT”部分。
该层的输出大小对应于词汇表中的标记数,并且不取决于XLS-R的预训练任务,而仅取决于用于微调的标记数据集。因此,在第一步中,我们将查看Common Voice选择的数据集,并根据转录定义一个词汇表。
首先,让我们转到Common Voice官方网站,并选择一种语言进行XLS-R的微调。在本笔记本中,我们将使用土耳其语。
对于每个特定语言的数据集,您可以找到与您选择的语言对应的语言代码。在Common Voice上,查找“Version”字段。然后,语言代码对应于下划线之前的前缀。例如,对于土耳其语,语言代码是"tr"
。
太好了,现在我们可以使用🤗 Datasets的简单API下载数据。数据集名称为"common_voice"
,配置名称对应于语言代码,在我们的案例中是"tr"
。
Common Voice有许多不同的拆分,包括invalidated
,指的是未被评定为“足够清洁”以被视为有用的数据。在本笔记本中,我们只使用"train"
,"validation"
和"test"
这些拆分。
由于土耳其数据集非常小,我们将合并验证和训练数据为一个训练数据集,并仅使用测试数据进行验证。
from datasets import load_dataset, load_metric, Audio
common_voice_train = load_dataset("common_voice", "tr", split="train+validation")
common_voice_test = load_dataset("common_voice", "tr", split="test")
许多ASR数据集只提供每个音频数组'audio'
和文件'path'
的目标文本'sentence'
。Common Voice实际上提供了关于每个音频文件的更多信息,例如'accent'
等。为了使笔记本尽可能通用,我们只考虑用于微调的转录文本。
common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
让我们编写一个简短的函数来显示数据集的一些随机样本,并运行几次以了解转录的感觉。
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML
def show_random_elements(dataset, num_examples=10):
assert num_examples <= len(dataset), "无法选择比数据集中的元素更多的元素。"
picks = []
for _ in range(num_examples):
pick = random.randint(0, len(dataset)-1)
while pick in picks:
pick = random.randint(0, len(dataset)-1)
picks.append(pick)
df = pd.DataFrame(dataset[picks])
display(HTML(df.to_html()))
打印输出:
好的!转录看起来相当干净。在翻译了转录的句子之后,似乎语言更对应于书面文本而不是杂乱的对话。考虑到 Common Voice 是一个众包的朗读语音语料库,这是有道理的。
我们可以看到转录中包含一些特殊字符,例如,.?!;:
。如果没有语言模型,要将语音块分类到这些特殊字符会更加困难,因为它们实际上并不对应于一个特征音单元。例如,字母"s"
有一个相对清晰的发音,而特殊字符"."
则没有。此外,为了理解语音信号的含义,通常不需要在转录中包含特殊字符。
让我们只删除所有不对单词的含义有贡献且实际上不能用声音表示的字符,并将文本标准化。
import re
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'
def remove_special_characters(batch):
batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
return batch
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)
让我们再次查看处理后的文本标签。
show_random_elements(common_voice_train.remove_columns(["path","audio"]))
打印输出:
很好!看起来更好了。我们已经从转录中删除了大部分特殊字符,并将它们标准化为小写字母。
在最终进行预处理之前,与目标语言的母语人士进行咨询,看看文本是否可以进一步简化,这总是有利的。对于这篇博文,Merve 非常友好地看了一眼,并指出土耳其语中不再使用“带帽”的字符,如â
,可以用它们的“无帽”等效字符替代,例如a
。
这意味着我们应该将句子"yargı sistemi hâlâ sağlıksız"
替换为"yargı sistemi hala sağlıksız"
。
让我们编写另一个简短的映射函数来进一步简化文本标签。请记住,文本标签越简单,模型学习预测这些标签就越容易。
def replace_hatted_characters(batch):
batch["sentence"] = re.sub('[â]', 'a', batch["sentence"])
batch["sentence"] = re.sub('[î]', 'i', batch["sentence"])
batch["sentence"] = re.sub('[ô]', 'o', batch["sentence"])
batch["sentence"] = re.sub('[û]', 'u', batch["sentence"])
return batch
common_voice_train = common_voice_train.map(replace_hatted_characters)
common_voice_test = common_voice_test.map(replace_hatted_characters)
在 CTC 中,常常将语音块分类为字母,所以我们在这里也会这样做。让我们提取训练和测试数据的所有不同字母,并从这组字母构建我们的词汇表。
我们编写一个映射函数,将所有转录连接成一个长转录,然后将字符串转换为字符集。重要的是向map(...)
函数传递参数batched=True
,以便映射函数一次性访问所有转录。
def extract_all_chars(batch):
all_text = " ".join(batch["sentence"])
vocab = list(set(all_text))
return {"vocab": [vocab], "all_text": [all_text]}
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
现在,我们创建训练数据集和测试数据集中所有不同字母的并集,并将结果列表转换为枚举字典。
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
打印输出:
{
' ': 0,
'a': 1,
'b': 2,
'c': 3,
'd': 4,
'e': 5,
'f': 6,
'g': 7,
'h': 8,
'i': 9,
'j': 10,
'k': 11,
'l': 12,
'm': 13,
'n': 14,
'o': 15,
'p': 16,
'q': 17,
'r': 18,
's': 19,
't': 20,
'u': 21,
'v': 22,
'w': 23,
'x': 24,
'y': 25,
'z': 26,
'ç': 27,
'ë': 28,
'ö': 29,
'ü': 30,
'ğ': 31,
'ı': 32,
'ş': 33,
'̇': 34
}
很棒,我们看到数据集中包含了字母表中的所有字母(这并不令人惊讶),我们还提取了特殊字符 ""
和 '
。请注意,我们没有排除这些特殊字符,因为:
模型必须学会预测单词何时结束,否则模型的预测将始终是一个字符序列,这将使得无法将单词分开。
在训练模型之前,我们应该始终记住预处理是一个非常重要的步骤。例如,我们不希望模型因为我们忘记对数据进行规范化而区分 a
和 A
。 a
和 A
之间的区别根本不取决于字母的“音” ,而更多的是语法规则 – 例如在句子开头使用大写字母。因此,消除大写和非大写字母之间的差异是明智的,这样模型在学习转录语音时会更容易。
为了更清楚地说明 " "
具有自己的标记类,我们给它一个更明显的字符 |
。此外,我们还添加了一个“未知”标记,以便模型稍后可以处理在 Common Voice 的训练集中未遇到的字符。
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
最后,我们还添加了一个填充标记,该标记对应于 CTC 的“空白标记”。“空白标记”是 CTC 算法的核心组件。有关更多信息,请参阅此处的“对齐”部分。
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
很棒,现在我们的词汇表已经完整,由39个标记组成,这意味着我们将在预训练的 XLS-R 检查点之上添加的线性层的输出维度为39。
现在让我们将词汇表保存为一个 json 文件。
import json
with open('vocab.json', 'w') as vocab_file:
json.dump(vocab_dict, vocab_file)
在最后一步中,我们使用 json 文件将词汇表加载到 Wav2Vec2CTCTokenizer
类的实例中。
from transformers import Wav2Vec2CTCTokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
如果要重复使用刚刚创建的 tokenizer 与此笔记本的微调模型,强烈建议将 tokenizer
上传到 Hugging Face Hub。让我们将要上传文件的 repo 叫做 "wav2vec2-large-xlsr-turkish-demo-colab"
:
repo_name = "wav2vec2-large-xls-r-300m-tr-colab"
并将分词器上传到🤗 Hub。
tokenizer.push_to_hub(repo_name)
太棒了,您可以在https://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-tr-colab
下看到刚刚创建的存储库
创建Wav2Vec2FeatureExtractor
语音是一个连续的信号,为了让计算机处理,首先必须将其离散化,通常称为采样。采样率在此起着重要作用,因为它定义了每秒测量的语音信号数据点数量。因此,使用更高采样率的采样结果更接近真实语音信号的近似,但也需要更多的每秒数值。
预训练的检查点期望其输入数据与其训练的数据从同一分布中进行了更多或更少的采样。以两种不同速率采样的相同语音信号具有非常不同的分布。例如,加倍采样率会导致数据点变长两倍。因此,在微调ASR模型的预训练检查点之前,关键是验证用于预训练模型的数据的采样率是否与用于微调模型的数据的采样率匹配。
XLS-R是在16kHz采样率下对Babel、多语言LibriSpeech(MLS)、Common Voice、VoxPopuli和VoxLingua107的音频数据进行预训练的。Common Voice在其原始形式下采样率为48kHz,因此我们将需要将微调数据降采样到16kHz。
要实例化Wav2Vec2FeatureExtractor
对象,需要以下参数:
feature_size
:语音模型将一系列特征向量作为输入。尽管此序列的长度显然会有所不同,但特征大小不应该变化。在Wav2Vec2的情况下,特征大小为1,因为该模型是在原始语音信号上进行训练的。sampling_rate
:模型训练的采样率。padding_value
:对于批处理推断,较短的输入需要用特定值进行填充。do_normalize
:是否对输入进行零均值单位方差规范化。通常,对输入进行规范化可以提高语音模型的性能。return_attention_mask
:模型在批处理推断中是否应使用attention_mask
。通常情况下,XLS-R模型检查点应始终使用attention_mask
。
from transformers import Wav2Vec2FeatureExtractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
太棒了,XLS-R的特征提取流程已完全定义!
为了提高用户友好性,特征提取器和分词器被封装到一个名为Wav2Vec2Processor
的类中,因此只需要一个model
和processor
对象。
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
接下来,我们可以准备数据集。
预处理数据
到目前为止,我们还没有看到语音信号的实际值,只看到了转录。除了sentence
之外,我们的数据集还包括两个列名path
和audio
。path
表示音频文件的绝对路径。让我们来看一下。
common_voice_train[0]["path"]
XLS-R期望以16 kHz的1维数组格式输入。这意味着音频文件必须加载和重新采样。
幸运的是,datasets
会自动调用另一列audio
来完成这个过程。让我们试一试。
common_voice_train[0]["audio"]
{'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-8.8930130e-05, -3.8027763e-05, -2.9146671e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
'sampling_rate': 48000}
很好,我们可以看到音频文件已经自动加载了。这要归功于 datasets == 1.18.3
版本中引入的新功能 "Audio"
,在调用时会动态加载和重采样音频文件。
从上面的例子中,我们可以看到音频数据的采样率为 48kHz,而模型期望的采样率是 16kHz。我们可以使用 cast_column
将音频特征设置为正确的采样率:
common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))
让我们再次查看一下 "audio"
。
common_voice_train[0]["audio"]
{'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-7.4556941e-05, -1.4621433e-05, -5.7861507e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
'sampling_rate': 16000}
看起来似乎已经生效了!让我们听几个音频文件,以更好地了解数据集,并验证音频是否正确加载。
import IPython.display as ipd
import numpy as np
import random
rand_int = random.randint(0, len(common_voice_train)-1)
print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)
打印输出:
sunulan bütün teklifler i̇ngilizce idi
看起来数据现在已经正确加载和重采样了。
可以听到说话人随着他们的说话速度、口音和背景环境等发生变化。总体上,录音听起来还算清晰,这是可以预期的,因为它是一个众包的朗读语音语料库。
让我们通过打印语音输入的形状、转录文本和相应的采样率来最后检查数据是否准备好。
rand_int = random.randint(0, len(common_voice_train)-1)
print("目标文本:", common_voice_train[rand_int]["sentence"])
print("输入数组形状:", common_voice_train[rand_int]["audio"]["array"].shape)
print("采样率:", common_voice_train[rand_int]["audio"]["sampling_rate"])
打印输出:
目标文本: makedonya bu yıl otuz adet tyetmiş iki tankı aldı
输入数组形状: (71040,)
采样率: 16000
很好!一切看起来都正常 – 数据是一个一维数组,采样率始终对应 16kHz,目标文本已经标准化。
最后,我们可以利用 Wav2Vec2Processor
将数据处理成 Wav2Vec2ForCTC
训练所期望的格式。为此,让我们使用 Dataset 的 map(...)
函数。
首先,我们通过调用 batch["audio"]
来加载和重采样音频数据。其次,我们从加载的音频文件中提取 input_values
。在我们的情况下,Wav2Vec2Processor
只是对数据进行了标准化。然而,对于其他语音模型,这一步骤可能包括更复杂的特征提取,比如对数梅尔特征提取。第三,我们将转录文本编码为标签 id。
注意:这个映射函数是使用 Wav2Vec2Processor
类的好示例。在“正常”上下文中,调用 processor(...)
会重定向到 Wav2Vec2FeatureExtractor
的调用方法。然而,当将处理器包装到 as_target_processor
上下文中时,相同的方法会重定向到 Wav2Vec2CTCTokenizer
的调用方法。更多信息请查看文档。
def prepare_dataset(batch):
audio = batch["audio"]
# batched output is "un-batched"
batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
batch["input_length"] = len(batch["input_values"])
with processor.as_target_processor():
batch["labels"] = processor(batch["sentence"]).input_ids
return batch
让我们将数据准备函数应用于所有的例子。
common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)
注意:目前 datasets
使用 torchaudio
和 librosa
进行音频加载和重采样。如果您希望实现自己定制的数据加载/重采样,可以直接使用 "path"
列,并忽略 "audio"
列。
长的输入序列需要大量的内存。XLS-R 是基于 self-attention
的。对于长的输入序列,内存需求随输入长度的平方增长(参见这篇 reddit 帖子)。如果您在运行这个演示时遇到“内存不足”的错误,您可以取消注释以下行,以过滤掉所有长度超过 5 秒的序列:
#max_input_length_in_sec = 5.0
#common_voice_train = common_voice_train.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])
太棒了,现在我们准备开始训练!
训练
数据已经处理好了,我们可以开始设置训练流程。我们将使用 🤗 的 Trainer,以下是我们需要做的:
-
定义一个数据收集器。与大多数 NLP 模型不同,XLS-R 的输入长度要比输出长度长得多。例如,输入长度为 50000 的样本的输出长度不会超过 100。由于输入尺寸较大,使用动态填充训练批次会更加高效,这意味着所有的训练样本只需填充到其批次中最长的样本长度,而不是整个数据集中最长的样本长度。因此,微调 XLS-R 需要使用特殊的填充数据收集器,我们将在下面定义它
-
评估指标。在训练过程中,模型应该根据单词错误率进行评估。我们需要相应地定义一个
compute_metrics
函数 -
加载预训练的检查点。我们需要加载预训练的检查点,并正确配置它以进行训练
-
定义训练配置
在微调模型之后,我们将在测试数据上进行正确评估,并验证它是否确实学会了正确的语音转录。
设置 Trainer
让我们从定义数据收集器开始。数据收集器的代码是从这个示例中复制过来的。
不深入细节,与常见的数据收集器不同,这个数据收集器对待 input_values
和 labels
的方式不同,因此在它们上面应用不同的填充函数(再次利用 XLS-R 处理器的上下文管理器)。这是必要的,因为在语音中输入和输出是不同的模态,意味着它们不应该使用相同的填充函数进行处理。类似于常见的数据收集器,标签中的填充标记为 -100
,这样在计算损失时就不会将这些标记考虑在内。
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
@dataclass
class DataCollatorCTCWithPadding:
"""
动态填充输入数据的数据整理器。
参数:
processor (:class:`~transformers.Wav2Vec2Processor`)
用于处理数据的处理器。
padding (:obj:`bool`, :obj:`str` 或 :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `可选`, 默认值为 :obj:`True`):
选择一种策略来填充返回的序列(根据模型的填充边和填充索引):
* :obj:`True` 或 :obj:`'longest'`: 填充到批次中最长的序列(如果只提供单个序列,则不进行填充)。
* :obj:`'max_length'`: 填充到指定的最大长度(使用参数 :obj:`max_length`)或者填充到模型可接受的最大输入长度(如果未提供该参数)。
* :obj:`False` 或 :obj:`'do_not_pad'` (默认值): 不进行填充(即可以输出长度不同的批次)。
"""
processor: Wav2Vec2Processor
padding: Union[bool, str] = True
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# 将输入和标签拆分,因为它们的长度必须不同并且需要使用不同的填充方法
input_features = [{"input_values": feature["input_values"]} for feature in features]
label_features = [{"input_ids": feature["labels"]} for feature in features]
batch = self.processor.pad(
input_features,
padding=self.padding,
return_tensors="pt",
)
with self.processor.as_target_processor():
labels_batch = self.processor.pad(
label_features,
padding=self.padding,
return_tensors="pt",
)
# 使用 -100 替换填充部分以正确地忽略损失
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
batch["labels"] = labels
return batch
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
接下来,定义了评估指标。正如前面提到的,ASR 中主要的指标是词错误率 (WER),因此我们也将在本笔记本中使用它。
wer_metric = load_metric("wer")
模型将返回一个 logit 向量序列:y 1 , … , y m \mathbf{y}_1, \ldots, \mathbf{y}_m y 1 , … , y m ,其中 y 1 = f θ ( x 1 , … , x n ) [ 0 ] \mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0] y 1 = f θ ( x 1 , … , x n ) [ 0 ],且 n > > m n >> m n > > m 。
一个 logit 向量 y 1 \mathbf{y}_1 y 1 包含了我们之前定义的词汇表中每个单词的对数几率,因此 len ( y i ) = \text{len}(\mathbf{y}_i) = len ( y i ) = config.vocab_size
。我们对模型的最可能预测感兴趣,因此取 logits 的 argmax(...)
。此外,我们通过用 pad_token_id
替换编码的标签并解码 id 来将其转换回原始字符串,同时确保连续的标记不在 CTC 样式中分组 1 {}^1 1 。
def compute_metrics(pred):
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)
pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
pred_str = processor.batch_decode(pred_ids)
# 在计算指标时,我们不希望将标记组合在一起
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
wer = wer_metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
现在,我们可以加载预训练的 Wav2Vec2-XLS-R-300M 检查点。为了定义模型的 pad_token_id
或者在 Wav2Vec2ForCTC
的情况下定义 CTC 的 blank token 2 {}^2 2 ,需要设置 tokenizer 的 pad_token_id
。为了节省 GPU 内存,我们启用了 PyTorch 的梯度检查点,并将损失减少设置为 ” mean “。
因为数据集非常小(约6小时的训练数据),而且Common Voice的噪声很大,因此微调Facebook的wav2vec2-xls-r-300m检查点似乎需要一些超参数调整。因此,我不得不尝试不同的值,包括dropout、SpecAugment的掩码dropout率、层dropout和学习率,直到训练似乎足够稳定为止。
注意:当使用这个笔记本在Common Voice的其他语言上训练XLS-R时,这些超参数设置可能效果不佳。根据您的用例自由调整这些参数。
from transformers import Wav2Vec2ForCTC
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-xls-r-300m",
attention_dropout=0.0,
hidden_dropout=0.0,
feat_proj_dropout=0.0,
mask_time_prob=0.05,
layerdrop=0.0,
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
vocab_size=len(processor.tokenizer),
)
XLS-R的第一个组件由一堆CNN层组成,用于从原始语音信号中提取语音上下文无关的有意义的特征。这部分模型在预训练过程中已经得到了充分的训练,并且根据论文中的陈述,不需要再进行微调。因此,我们可以将特征提取部分的所有参数的requires_grad
设置为False
。
model.freeze_feature_extractor()
最后,我们定义了与训练相关的所有参数。对于其中一些参数的更多解释:
group_by_length
通过将输入长度相似的训练样本分组成一个批次,使训练更加高效。这可以通过大大减少通过模型的无用填充令牌的总数来显著加快训练时间。learning_rate
和weight_decay
是通过试探法调整的,直到微调稳定为止。请注意,这些参数严重依赖于Common Voice数据集,对于其他语音数据集可能不是最佳选择。
有关其他参数的更多解释,请查看文档。
在训练过程中,每400个训练步骤将异步上传一个检查点到Hub。这样,即使模型仍在训练中,您也可以玩弄演示小部件。
注意:如果不想将模型检查点上传到Hub,请将push_to_hub=False
。
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir=repo_name,
group_by_length=True,
per_device_train_batch_size=16,
gradient_accumulation_steps=2,
evaluation_strategy="steps",
num_train_epochs=30,
gradient_checkpointing=True,
fp16=True,
save_steps=400,
eval_steps=400,
logging_steps=400,
learning_rate=3e-4,
warmup_steps=500,
save_total_limit=2,
push_to_hub=True,
)
现在,所有实例都可以传递给Trainer,我们准备开始训练!
from transformers import Trainer
trainer = Trainer(
model=model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=common_voice_train,
eval_dataset=common_voice_test,
tokenizer=processor.feature_extractor,
)
1 {}^1 1 为了使模型不依赖于说话者的语速,在CTC中,连续的相同的标记被简单地分组为一个标记。然而,在解码时,编码的标签不应该被分组,因为它们不对应于模型预测的标记,这就是为什么必须传递group_tokens=False
参数的原因。如果不传递这个参数,像"hello"
这样的单词将被错误地编码和解码为"helo"
。
训练
训练将根据分配给该笔记本的GPU而耗费多个小时。尽管经过训练的模型在土耳其Common Voice的测试数据上产生了一定令人满意的结果,但它并不是一个经过最优微调的模型。该笔记本的目的只是演示如何在ASR数据集上微调XLS-R XLSR-Wav2Vec2。
根据分配给您的Google Colab的GPU,可能会出现”out-of-memory”错误。在这种情况下,最好将per_device_train_batch_size
减小为8或更小,并增加gradient_accumulation
。
trainer.train()
输出结果:
训练损失和验证WER都有很好的下降趋势。
您现在可以将训练结果上传到Hub,只需执行以下指令:
trainer.push_to_hub()
您现在可以与所有的朋友、家人、喜爱的宠物共享这个模型,他们可以使用标识符”your-username/the-name-you-picked”加载它,例如:
from transformers import AutoModelForCTC, Wav2Vec2Processor
model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
有关如何微调XLS-R的更多示例,请查看官方的🤗 Transformers示例。
评估
作为最后的检查,让我们加载模型并验证它是否确实学会了转录土耳其语音。
首先加载预训练的检查点。
model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)
现在,我们将只取测试集的第一个示例,将其通过模型运行,并通过argmax(...)
获取对数以检索预测的标记ID。
input_dict = processor(common_voice_test[0]["input_values"], return_tensors="pt", padding=True)
logits = model(input_dict.input_values.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)[0]
强烈建议将sampling_rate
参数传递给此函数。不这样做可能会导致难以调试的静默错误。
我们对common_voice_test
进行了相当大的调整,以使数据集实例不再包含原始句子标签。因此,我们重复使用原始数据集来获取第一个示例的标签。
common_voice_test_transcription = load_dataset("common_voice", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test")
最后,我们可以对示例进行解码。
print("Prediction:")
print(processor.decode(pred_ids))
print("\nReference:")
print(common_voice_test_transcription[0]["sentence"].lower())
输出结果:
好了!从我们的预测中,可以明显地识别出转录内容,但还不完美。通过更长时间的训练模型、更多的数据预处理时间,特别是使用语言模型进行解码,可以显著提高模型的整体性能。
对于一个低资源语言的演示模型来说,结果还是相当可接受的🤗。