7个使用PyTorch的技巧,含在线代码示例!网友:我连第一个都不知道?!

7个使用PyTorch的技巧,含在线代码示例!网友:我连第一个都不知道?!
2021年05月14日 15:45 量子位

大家在使用PyTorch时,是不是也踩过不少坑?

现在,Reddit上的一位开发者根据他曾经犯过的错和经常忘记的点,总结了七点使用PyTorch的小技巧,供大家参考。

该分享目前在Reddit上得到了300+的支持。

很多人表示很有用,并有人指出这些不仅仅是tips,是每个人在使用Pytorch之前应该阅读的教程的一部分

这位分享者还提供了在线代码示例和视频演示。

接下来就为大家一一展示,请大家按需汲取!

1、使用device参数直接在目标设备创建张量

这样速度会更快!在线示例代码显示,直接在GPU上创建只需0.009s:

对此,有网友补充道,之所以这样更快,是因为使用device参数是直接在GPU上创建张量,而不是在CPU上创建后再复制到GPU。

并且这样以来,使用的RAM更少,也不会留下CPU张量hanging around的风险

2、可能的话使用Sequential层

为了代码更干净。

下面是部分示例代码:

3、不要列层列表

因为它们不能被nn.Module类正确注册。相反,应该将层列表作为未打包的参数传递到一个Sequential层中。

以上两点有争议:有人认为从代码正确性来看,使用nn.Sequential没毛病,但是从代码可读性来看,应该使用nn.ModuleList,除非只是在堆叠(stack)层。

他还给出了官方链接佐证(详情可见文末链接[3]),该观点得到了不少赞同。

另外针对第三点建议,有人不明白如何将列表作为未打包的参数传递给Sequential,并获得相同的结果。

有人作出了解答:两者都可索引寻址和遍历。只是ModuleList只保存不知道如何使用它们的模块,而sequential则按它们在列表中的顺序运行层。

下面是分享者提供的示例代码:

4、充分利用torch.distributions

PyTorch有一些不错的对象和函数用于distribution,但这位开发者认为它们在torch.distributions中没有得到充分利用。可以这样使用:

5、对长度量(Long-Term Metrics)使用detach()

在两个epochs之间存储张量度量时,请确保对它们调用.detach(),以避免内存泄漏

6、删除模型时,使用torch.cuda.empty_cache()清除GPU缓存

尤其是在使用笔记本删除并重新创建大型模型时。

7、预测之前一定记得调用model.eval()

是不是很多人都忘记了?

如果你忘记调用model.eval(),也就是忘记将模型转变为evaluation(测试)模式,那么Dropout层和Batch Normalization层就会对你的预测数据造成干扰。

以上就是这位开发者总结的7点PyTorch使用小技巧。

有人表示,“我居然连第一个技巧都不知道”!

你是否知道呢?

最后,如果你对哪点有疑问或还有其他使用PyTorch时的小技巧,欢迎在评论区开麦!

在线代码示例:https://colab.research.google.com/drive/15vGzXs_ueoKL0jYpC4gr9BCTfWt935DC

视频演示:https://www.youtube.com/watch?v=BoC8SGaT3GE

参考链接:

[1]https://www.reddit.com/r/MachineLearning/comments/n9fti7/d_a_few_helpful_pytorch_tips_examples_included/

[2]https://gist.github.com/ejmejm/1baeddbbe48f58dbced9c019c25ebf71

[3]https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html

财经自媒体联盟更多自媒体作者

新浪首页 语音播报 相关新闻 返回顶部