优秀的编程知识分享平台

网站首页 > 技术文章 正文

AWD-LSTM语言模型是如何实现的_lstm语言模型

nanyue 2025-10-02 04:55:43 技术文章 2 ℃

图:pixabay

原文来源:github、arxiv

作者:Stephen Merity

「机器人圈」编译:嗯~阿童木呀、多啦A亮

具有权重下降LSTM的平均随机梯度下降

该存储库包含用于Salesforce Research的论文《正规化和优化LSTM语言模型》(
https://arxiv.org/abs/1708.02182)的代码,最初派生于PyTorch词级语言建模样本(
https://github.com/pytorch/examples/tree/master/word_language_model)。同时,该模型还附带了在Penn Treebank(PTB)和WikiText-2(WT2)数据集上训练词级语言模型的指令,尽管该模型可能扩展到许多其他数据集上。

安装PyTorch 0.1.12_2;

运行getdata.sh以便获取Penn Treebank和WikiText-2数据集;

使用main.py训练基本模型;

使用finetune.py 微调模型;

使用pointer.py将连续缓存指针应用于finetuned模型。

如果你在研究中使用此代码或我们的研究成果,请引用:

@article{merityRegOpt, title={{Regularizing and Optimizing LSTM Language Models}}, author={Merity, Stephen and Keskar, Nitish Shirish and Socher, Richard}, journal={arXiv preprint arXiv:1708.02182}, year={2017} }

软件要求

该代码库需要Python 3和PyTorch 0.1.12_2。

注意:旧版本的PyTorch升级到更高版本将需要做一些微小的更新,并将阻止下面结果的精准复制。欢迎更新到稍后PyTorch版本中,特别是如果他们已有基准数据报告。

实验

在撰写本文时,代码库已被修改,阻止了由于随机种子或类似的微小差异而导致的精确复制。下面的指南产生的结果大体上与所报告的数字相类似。

对于数据设置,运行./getdata.sh。该脚本收集了Mikolov预处理的Penn Treebank和WikiText-2数据集,并将它们放在data目录中。

重要提示:如果你想要在基本实验基础之上继续实验,请注释测试代码并使用验证指标,直到报告出你的最终结果。这是正确的实验实践,并且在调整超参数(如指针使用的参数)时尤为重要。

Penn Treebank(PTB)

下面的指令是在PTB上训练模型,即在没有微调的情况下达到了61.2 / 58.9(验证/测试)的复杂度,而在微调的情况下达到了58.8 / 56.6的复杂度,在配置连续缓存指针后达到了53.5 / 53.0的复杂度。

首先,训练模型:

python main.py --batch_size 20 --data data/penn --dropouti 0.4 --seed 28 --epoch 300 --save PTB.pt

第一轮训练应该会产生一个结果为308.03的验证复杂度。

然后微调模型:

python finetune.py --batch_size 20 --data data/penn --dropouti 0.4 --seed 28 --epoch 300 --save PTB.pt

第一轮训练后的验证复杂度应为60.85。

注意:微调修改后以PTB.pt形式保存原始模型—如果你想要保留原始权重,必须复制文件。

最后,运行指针:

python pointer.py --data data/penn --save PTB.pt --lambdasm 0.1 --theta 1.0 --window 500 --bptt 5000

请注意,本文中的模型训练了500轮次,批量大小为40,而上述模型的这两个参数值分别为300和20。如本文中所述,该指针的窗口大小选择为500而不是2000。

注意:BPTT只是改变了推送到GPU上的序列的长度,但不会影响最终的结果。

WikiText-2(WT2)

下面的指令是在WT2上训练模型,即在没有微调的情况下达到了69.1 / 66.1(验证/测试)的复杂度,而在微调的情况下达到了68.7 / 65.8的复杂度,在配置连续缓存指针后达到了53.6 / 52.0(特别情况下还有会51.95)的复杂度。

python main.py --seed 20923 --epochs 750 --data data/wikitext-2 --save WT2.pt

第一轮训练应该会产生一个结果为629.93的验证复杂度。

python -u finetune.py --seed 1111 --epochs 750 --data data/wikitext-2 --save WT2.pt

第一轮训练后的验证复杂度应该为69.14。

注意:微调修改后以PTB.pt形式保存原始模型,如果要保留原始权重,必须复制文件。

最后,运行指针:

python pointer.py --save WT2.pt --lambdasm 0.1279 --theta 0.662 --window 3785 --bptt 2000 --data data/wikitext-2

注意:

速度

对LSTM的增强有几下几种,包括我们的DropConnect(Wan等人于2013年提出)的变体,我们称之为权重下降,它能够增加循环丢失,允许使用NVIDIA的cuDNN LSTM实现。如果在安装了cuDNN的CUDA上运行,PyTorch将自动使用cuDNN后端。这就保证了模型的快速训练,即使在融合了可能需要数百个训练轮次的情况下也是如此。

在NVIDIA Quadro GP100训练期间,该模型的默认速度为:

Penn Treebank:批量大小为40的每轮训练时间约为45秒,批量大小为20的每轮训练时间约为65秒。

WikiText-2:批量大小为80的每轮训练时间约为105秒。

在K80上,速度要慢三倍左右。在K80或其他具有较少内存的存储卡上,你可能希望启用最大采样序列长度的上限以防止内存不足(OOM)错误,特别是对于WikiText-2而言。

如果速度是一个主要问题,那么SGD会比ASGD的非单调触发变体的收敛速度要快得多,但是整体上的复杂度也要糟糕得多。

在计算机领域内,一般来说,诸如长短期记忆网络(LSTM)这样的循环神经网络(RNNs)往往充当的是包括机器翻译、语言建模和问答等许多序列学习任务的基本构建模块。在本文中,我们考虑了词级语言建模的具体问题,并研究了正则化和优化基于LSTM模型的策略。我们建议使用权重下降的LSTM,它在hidden-to hidden权重中使用DropConnect作为一种循环正则化的形式。此外,我们还引入NT-ASGD,它是平均随机梯度法的变体,即使用非单调条件来确定平均触发器,而不是由用户进行调整。使用这些和其他正则化策略,我们在两个数据集上达到了最高水平的词级复杂度:在Penn Treebank 的57.3和在WikiText-2上的65.8。在探索神经缓存与我们提出的模型相结合的有效性时,我们达到了比当前最高水平的复杂度还要低的结果:在Penn Treebank上的52.8和在WikiText-2上的52.0。

最近发表
标签列表