网站首页 > 技术文章 正文
图:pixabay
原文来源:github、arxiv
作者:Stephen Merity
「机器人圈」编译:嗯~阿童木呀、多啦A亮
具有权重下降LSTM的平均随机梯度下降
该存储库包含用于Salesforce Research的论文《正规化和优化LSTM语言模型》(
https://arxiv.org/abs/1708.02182)的代码,最初派生于PyTorch词级语言建模样本(
https://github.com/pytorch/examples/tree/master/word_language_model)。同时,该模型还附带了在Penn Treebank(PTB)和WikiText-2(WT2)数据集上训练词级语言模型的指令,尽管该模型可能扩展到许多其他数据集上。
安装PyTorch 0.1.12_2;
运行getdata.sh以便获取Penn Treebank和WikiText-2数据集;
使用main.py训练基本模型;
使用finetune.py 微调模型;
使用pointer.py将连续缓存指针应用于finetuned模型。
如果你在研究中使用此代码或我们的研究成果,请引用:
@article{merityRegOpt, title={{Regularizing and Optimizing LSTM Language Models}}, author={Merity, Stephen and Keskar, Nitish Shirish and Socher, Richard}, journal={arXiv preprint arXiv:1708.02182}, year={2017} }
软件要求
该代码库需要Python 3和PyTorch 0.1.12_2。
注意:旧版本的PyTorch升级到更高版本将需要做一些微小的更新,并将阻止下面结果的精准复制。欢迎更新到稍后PyTorch版本中,特别是如果他们已有基准数据报告。
实验
在撰写本文时,代码库已被修改,阻止了由于随机种子或类似的微小差异而导致的精确复制。下面的指南产生的结果大体上与所报告的数字相类似。
对于数据设置,运行./getdata.sh。该脚本收集了Mikolov预处理的Penn Treebank和WikiText-2数据集,并将它们放在data目录中。
重要提示:如果你想要在基本实验基础之上继续实验,请注释测试代码并使用验证指标,直到报告出你的最终结果。这是正确的实验实践,并且在调整超参数(如指针使用的参数)时尤为重要。
Penn Treebank(PTB)
下面的指令是在PTB上训练模型,即在没有微调的情况下达到了61.2 / 58.9(验证/测试)的复杂度,而在微调的情况下达到了58.8 / 56.6的复杂度,在配置连续缓存指针后达到了53.5 / 53.0的复杂度。
首先,训练模型:
python main.py --batch_size 20 --data data/penn --dropouti 0.4 --seed 28 --epoch 300 --save PTB.pt
第一轮训练应该会产生一个结果为308.03的验证复杂度。
然后微调模型:
python finetune.py --batch_size 20 --data data/penn --dropouti 0.4 --seed 28 --epoch 300 --save PTB.pt
第一轮训练后的验证复杂度应为60.85。
注意:微调修改后以PTB.pt形式保存原始模型—如果你想要保留原始权重,必须复制文件。
最后,运行指针:
python pointer.py --data data/penn --save PTB.pt --lambdasm 0.1 --theta 1.0 --window 500 --bptt 5000
请注意,本文中的模型训练了500轮次,批量大小为40,而上述模型的这两个参数值分别为300和20。如本文中所述,该指针的窗口大小选择为500而不是2000。
注意:BPTT只是改变了推送到GPU上的序列的长度,但不会影响最终的结果。
WikiText-2(WT2)
下面的指令是在WT2上训练模型,即在没有微调的情况下达到了69.1 / 66.1(验证/测试)的复杂度,而在微调的情况下达到了68.7 / 65.8的复杂度,在配置连续缓存指针后达到了53.6 / 52.0(特别情况下还有会51.95)的复杂度。
python main.py --seed 20923 --epochs 750 --data data/wikitext-2 --save WT2.pt
第一轮训练应该会产生一个结果为629.93的验证复杂度。
python -u finetune.py --seed 1111 --epochs 750 --data data/wikitext-2 --save WT2.pt
第一轮训练后的验证复杂度应该为69.14。
注意:微调修改后以PTB.pt形式保存原始模型,如果要保留原始权重,必须复制文件。
最后,运行指针:
python pointer.py --save WT2.pt --lambdasm 0.1279 --theta 0.662 --window 3785 --bptt 2000 --data data/wikitext-2
注意:
速度
对LSTM的增强有几下几种,包括我们的DropConnect(Wan等人于2013年提出)的变体,我们称之为权重下降,它能够增加循环丢失,允许使用NVIDIA的cuDNN LSTM实现。如果在安装了cuDNN的CUDA上运行,PyTorch将自动使用cuDNN后端。这就保证了模型的快速训练,即使在融合了可能需要数百个训练轮次的情况下也是如此。
在NVIDIA Quadro GP100训练期间,该模型的默认速度为:
Penn Treebank:批量大小为40的每轮训练时间约为45秒,批量大小为20的每轮训练时间约为65秒。
WikiText-2:批量大小为80的每轮训练时间约为105秒。
在K80上,速度要慢三倍左右。在K80或其他具有较少内存的存储卡上,你可能希望启用最大采样序列长度的上限以防止内存不足(OOM)错误,特别是对于WikiText-2而言。
如果速度是一个主要问题,那么SGD会比ASGD的非单调触发变体的收敛速度要快得多,但是整体上的复杂度也要糟糕得多。
在计算机领域内,一般来说,诸如长短期记忆网络(LSTM)这样的循环神经网络(RNNs)往往充当的是包括机器翻译、语言建模和问答等许多序列学习任务的基本构建模块。在本文中,我们考虑了词级语言建模的具体问题,并研究了正则化和优化基于LSTM模型的策略。我们建议使用权重下降的LSTM,它在hidden-to hidden权重中使用DropConnect作为一种循环正则化的形式。此外,我们还引入NT-ASGD,它是平均随机梯度法的变体,即使用非单调条件来确定平均触发器,而不是由用户进行调整。使用这些和其他正则化策略,我们在两个数据集上达到了最高水平的词级复杂度:在Penn Treebank 的57.3和在WikiText-2上的65.8。在探索神经缓存与我们提出的模型相结合的有效性时,我们达到了比当前最高水平的复杂度还要低的结果:在Penn Treebank上的52.8和在WikiText-2上的52.0。
猜你喜欢
- 2025-10-02 基于深度学习的铸件缺陷检测_如何控制和检测铸件缺陷?有缺陷铸件如何处置?
- 2025-10-02 Linux Mint 22.1 Cinnamon Edition 搭建深度学习环境
- 2025-10-02 NVIDIA Jetson Nano 2GB 系列文章(53):TAO模型训练工具简介
- 2025-10-02 使用ONNX和Torchscript加快推理速度的测试
- 2025-10-02 tensorflow GPU环境安装踩坑日记_tensorflow配置gpu环境
- 2025-10-02 Keye-VL-1.5-8B 快手 Keye-VL— 腾讯云两卡 32GB GPU保姆级部署指南
- 2024-08-08 faster rcnn在ubuntu环境下使用GPU模式并用cuDNN v5加速
- 2024-08-08 这个中文深度学习书火了!基于TF 2.0,GitHub热榜第一,斩获2K星
- 2024-08-08 小白也能搞定!Windows10上CUDA9.0+CUDNN7.0.5的完美安装教程
- 2024-08-08 使用CUDA語言涉及安裝CUDA工具包、配置環境變量、...
- 10-02基于深度学习的铸件缺陷检测_如何控制和检测铸件缺陷?有缺陷铸件如何处置?
- 10-02Linux Mint 22.1 Cinnamon Edition 搭建深度学习环境
- 10-02AWD-LSTM语言模型是如何实现的_lstm语言模型
- 10-02NVIDIA Jetson Nano 2GB 系列文章(53):TAO模型训练工具简介
- 10-02使用ONNX和Torchscript加快推理速度的测试
- 10-02tensorflow GPU环境安装踩坑日记_tensorflow配置gpu环境
- 10-02Keye-VL-1.5-8B 快手 Keye-VL— 腾讯云两卡 32GB GPU保姆级部署指南
- 10-02Gateway_gateways
- 最近发表
-
- 基于深度学习的铸件缺陷检测_如何控制和检测铸件缺陷?有缺陷铸件如何处置?
- Linux Mint 22.1 Cinnamon Edition 搭建深度学习环境
- AWD-LSTM语言模型是如何实现的_lstm语言模型
- NVIDIA Jetson Nano 2GB 系列文章(53):TAO模型训练工具简介
- 使用ONNX和Torchscript加快推理速度的测试
- tensorflow GPU环境安装踩坑日记_tensorflow配置gpu环境
- Keye-VL-1.5-8B 快手 Keye-VL— 腾讯云两卡 32GB GPU保姆级部署指南
- Gateway_gateways
- Coze开源本地部署教程_开源canopen
- 扣子开源本地部署教程 丨Coze智能体小白喂饭级指南
- 标签列表
-
- cmd/c (90)
- c++中::是什么意思 (84)
- 标签用于 (71)
- 主键只能有一个吗 (77)
- c#console.writeline不显示 (95)
- pythoncase语句 (88)
- es6includes (74)
- sqlset (76)
- apt-getinstall-y (100)
- node_modules怎么生成 (87)
- chromepost (71)
- flexdirection (73)
- c++int转char (80)
- mysqlany_value (79)
- static函数和普通函数 (84)
- el-date-picker开始日期早于结束日期 (76)
- js判断是否是json字符串 (75)
- c语言min函数头文件 (77)
- asynccallback (87)
- localstorage.removeitem (74)
- vector线程安全吗 (70)
- java (73)
- js数组插入 (83)
- mac安装java (72)
- 无效的列索引 (74)