快速浏览Apache MXNet 1.2中的Swish激活函数

Apache MXNet 1.2是即将发布。正如更改日志所暗示的,这看起来像是一个主要版本,但在本文中,我们将重点介绍一个新的激活函数:Swish。

快速回顾激活函数

在深度神经网络中,激活函数的目的是引入非线性,即对神经元输出执行非线性决策阈值。从某种意义上说,我们试图模仿一种简单的方式,无疑是生物神经元的行为。

随着时间的推移,许多激活函数被设计出来,每一个新函数都试图克服其前辈的缺点。例如, Rectified Linear Unit function(又名ReLU)通过求解消失梯度问题对Sigmoid函数进行了改进。

当然,更好的激活函数的竞赛从未停止。2017年末,发现了一个新函数:Swish。

Swish函数

通过自动组合不同的数学运算符,Prajit Ramachandran,Barret Zoph和Quoc V评估了大量候选激活函数的性能。其中一个,他们称之为Swish,结果比其他好。

快速浏览Apache MXNet 1.2中的Swish激活函数

f(x)=x⋅sigmoid(βx)

正如你所看到的,如果β参数小,Swish是接近线性,当β参数很小,接近ReLU时大。β的最佳点似乎在1到2之间:它创建一个non-monotonic“bump”对于负似乎有趣的属性。

正如作者所强调的那样:“ 简单地用Swish单元代替ReLUs,对于ImageNet来说,移动NASNet-A和Inception-ResNet-v2的分类精度分别提高了0.9%和0.6%。

这听起来像是一个简单的改进,不是吗?让我们在MXNet上测试它!

Swish in MXNet

Swish可用于MXNet 1.2中的Gluon API。它在incubator-mxnet / python / mxnet / gluon / nn / activations.py中定义,并且在我们的Gluon代码中使用它很容易:nn.Swish()。

为了评估它的性能,我们将在CIFAR-10上训练VGG16卷积神经网络的两个不同版本:

  • 在Gluon model zoo中使用批量标准化的VGG16 ,即使用ReLU(incubator-mxnet / python / mxnet / gluon / model_zoo / vision / vgg.py)

  • 相同的网络被修改为使用Swish作为卷积层和全连接层。

这非常简单:从主分支开始,我们只需创建一个vggswish.py文件并用Swish代替ReLU,例如:

self.features.add(nn.Dense(4096,activation ='relu',weight_initializer ='normal',bias_initializer ='zeros'))

--->

self.features.add(nn.Dense(4096,weight_initializer ='normal',bias_initializer ='zeros'))

self.features.add(nn.Swish())

然后,我们将这组新模型插入到incubator-mxnet / python / mxnet / gluon / model_zoo / __ init__.py和voila!

Training on CIFAR-10

MXNet包含一个图像分类脚本,可让我们使用各种网络体系结构和数据集(incubator-mxnet / example / gluon / image_classification.py)进行训练。

我们将使用SGD进行时代步骤,每次将学习率除以10。

python image_classification.py --model vgg16_bn --batch-size = 128 --lr = 0.1 --lr-steps = '10,20,30'--epochs = 40

python image_classification.py --model vgg16_bn_swish --batch-size = 128 --lr = 0.1 --lr-steps = '10,20,30'--epochs = 40

结果如下。

快速浏览Apache MXNet 1.2中的Swish激活函数

VGG16与ReLU与VGG16与Swish

从一个单一的例子中得出结论总是困难的(尽管我确实做了很多不同的训练,结果是一致的)。在这里,我们可以看到比Swish版本似乎更快的训练和验证以及ReLU版本。

最高验证精度分别为14次迭代0.866186和19次迭代0.866001。差别很小,但是VGG16并不是一个非常深的网络。

相关推荐