华人本科生发布zero-shot最强的GPT-J!会算数,会编程,运行速度碾压GPT-3

华人本科生发布zero-shot最强的GPT-J!会算数,会编程,运行速度碾压GPT-3
2021年06月13日 12:01 新智元

GPT家族又添了一个新成员GPT-J!

在zero-shot任务上,这个GPT-J的性能和67亿参数的GPT-3(中等模型)相当,也是目前公开可用的Transformer语言模型中,在各种下游zero-shot任务上表现最好的。

与 Tensorflow + TPU 的组合相比,GPT-J 更加灵活,并且有更快速的推断性能。

与其他大规模模型开发相比,这个项目需要的人工时间要少得多,这表明 JAX + xmap + TPUs 是快速开发大规模模型的正确工具集。

最让人惊讶的是,他的作者Ben Wang还只是一个是宾夕法尼亚大学电气工程与管理专业的本科生,他的主要研究方向包括电力电子,嵌入式设备,fpga,区块链,计算机视觉,网络和人工智能。

GPT-J模型

GPT-J的构建基于Mesh Transformer JAX,是一个haiku库,使用 Jax 中的 xmap 操作符来实现Transformer的模型并行化。

Mesh TensorFlow (mtf)是一种用于分布式深度学习的语言,能够指定广泛的分布式张量计算。Mesh TensorFlow 的目的是形式化和实现在硬件/处理器上的计算图的分发策略。例如将批处理分成多行处理器,并将隐藏层中的单元分成多列处理器。在 TensorFlow 上实现了一个网格式的 TensorFlow 层,通常使用场景是大规模的训练和低延迟的并行推理。

Mesh Transformer JAX使用了JAX框架,并且他的并行机制与英伟达提出的original Megatron-lm 相似,由于采用了高速的二维网状网络,因此具有较高的并行效率。

这个库的设计目的是在 tpuv3上最多可伸缩到大约20B 参数,超越了其他的并行策略,如 GPT-NeoX 或 DeepSpeed。

GPT-J-6B模型的训练基于The Pile数据库,总共4000亿个词,使用TPU v3-256训练了5周的时间。

The Pile是一个825GB的, 多样化的开源语言建模数据集,由22个较小的、高质量的数据集合组成。特别是对于大型模型,数据源的多样性提高了模型的一般跨领域知识,以及下游泛化能力。在我们的评估中,不仅在The Pile上训练的模型在传统的语言建模基准中显示了相当的改进,而且在Pile BPB上也显示了显著的改进。

GPT-J的模型设计和超参数选择与6.7 b GPT-3的模型设计和超参数选择有一定的差异,例如使用的数据集The Pile与GPT-3不同;注意力(线性、局部/滑动窗口等)公式没有被用于简化,因为在这种规模下它不会显著提高吞吐量;每个注意头的尺寸设置为256,比同等尺寸的 GPT-3大两倍。这显著提高了吞吐量,性能降低最小。

在架构上还做了两个改进:

为了更好的性能表现,使用Rotary embedding。 

 将注意力层和前馈层并行放置,以减少通信量。

Zero-shot的性能大致相当于 同尺寸的GPT-3,比GPT-Neo的模型性能更强。

6B GPT-J 的训练吞吐量(151k 词/s)比同一硬件(TPU v3-256 pod)上的2.7 b GPT-Neo (148k 词/s)快,效率提高约125%。

在6B 配置的 TPU V3-256 pod,GPT-J 达到高绝对效率。实验结果表明,GPT-J 的理论最大值为13.4 PFLOPs,GPT3论文测量值为5.4 PFLOPs (忽略了注意力计算,忽略了计算-内存的权衡,如梯度检查点)。当考虑到这些额外的因素,8.1 PFLOPs,或大约60% 的理论最大利用。

同时作者提供了一些有趣的样本,可以根据提示,GPT-J来续写,在算数方面也可以有较高的正确率。

官方还提供了一个界面,输入进去后果然可以得到正确的结果:

但对于有些算数问题来说,可能就不太管用了,如下图所示,可能更多的是基于检索,而非计算。

GPT-J甚至还可以定理证明!

虽然该模型在一定程度上模拟了这个简单定理的证明风格,但与人类水平的准确性还有很大差距。

对于自然语言理解的问题更不在话下:

完成 BoolQ (SuperGLUE)提出的一个问题。两种抽样方法都得出了相同的正确结论,但nucleus抽样方法存在错误推理,贪婪抽样方法的答案简洁合理。一般来说,我们观察到贪婪抽样比核抽样更准确,当输出被认为像这样短时,贪婪抽样比核抽样包含更少的误差,这是可以预测的,因为分类任务通常是用贪婪抽样完成的。

还能够编程!

语言模型只能模仿形式,内在的逻辑的理解,也许还有很长的路要走。

参考资料:

https://www.reddit.com/r/MachineLearning/comments/nvkowg/p_gptj_6b_jaxbased_transformer_lm/

财经自媒体联盟更多自媒体作者

新浪首页 语音播报 相关新闻 返回顶部