纯Python实现鸢尾属植物数据集神经网络模型

尝试使用过各大公司推出的植物识别APP吗?比如微软识花、花伴侣等这些APP。当你看到一朵不知道学名的花时,只需要打开植物识别APP,拍摄一张你所想辨认的植物照片并上传,APP会自动识别出该花的品种及详细介绍,感觉手机中装了一个知识渊博的生物学家,是不是很神奇?其实,背后的原理很简单,是一个图像分类的过程,将上传的图像与手机中预存的数据集或联网数据进行匹配,将其分类到对应的类别即可。随着深度学习方法的应用,图像分类的精度越来越高,在部分数据集上已经超越了人眼的能力。

相对于传统神经网络的方法而言,深度学习方法一般对数据集规模、硬件平台有着比较高的要求,如果只是单纯的想尝试了解图像分类任务的基本流程,建议采用小数据集样本及传统的神经网络方法实现。本文将带领读者采用鸢尾属植物数据集(Iris Data Set)来实现一个分类任务,整个鸢尾属植物数据集是机器学习中历史悠久的数据集,比现在常用的数字手写体数据集(Mnist Data Set)数据集还要早得多,该数据集来源于英国著名的统计学家、生物学家Ronald Fiser。本文在不使用相关软件库的情况下,从头开始构建针对鸢尾属植物数据的神经网络模型,对其进行训练并获得好的结果。

纯Python实现鸢尾属植物数据集神经网络模型

鸢尾属植物数据集是用于测试机器学习算法的最常用数据集。该数据包含四种特征,萼片长度、萼片宽度、花瓣长度和花瓣宽度,用于鸢尾属植物的不同物种(versicolor, virginica和setosa)。此外,每个物种有50个实例(数据行),下面让我们看看样本数据分布情况。

纯Python实现鸢尾属植物数据集神经网络模型

我们将在这个数据集上使用神经网络构建分类模型。为了简单起见,使用花瓣长度和花瓣宽度作为特征,且只有两类物种:versicolor和virginica。下面就让我们在Python中逐步训练针对该样本数据集的神经网络:

步骤1:准备鸢尾属植物数据集

将Iris数据集导入python并对数据进行子集划分以保留行之间的相关性:

纯Python实现鸢尾属植物数据集神经网络模型

纯Python实现鸢尾属植物数据集神经网络模型

蓝色点代表Versicolor物种,红色点代表Virginica物种。本文构建的神经网络将在这些数据上进行训练,以期最后能正确地分类物种。

步骤2:初始化参数(权重和偏置)

下面构建一个具有单个隐藏层的神经网络。此外,将隐藏图层的大小设置为6:

纯Python实现鸢尾属植物数据集神经网络模型

步骤3:前向传播(forward propagation)

在前向传播过程中,使用tanh激活函数作为第一层的激活函数,使用sigmoid激活函数作为第二层的激活函数:

纯Python实现鸢尾属植物数据集神经网络模型

步骤4:计算代价函数(cost function)

目标是使得计算的代价函数小化,本文采用交叉熵(cross-entropy)作为代价函数:

纯Python实现鸢尾属植物数据集神经网络模型

步骤5:反向传播(back propagation)

计算反向传播过程,主要是计算代价函数的导数:

纯Python实现鸢尾属植物数据集神经网络模型

步骤6:更新参数

使用反向传播过程中计算的梯度来更新权重和偏置:

纯Python实现鸢尾属植物数据集神经网络模型

步骤7:建立神经网络

将以上所有函数组合起来以创建设计的神经网络模型。总而言之,下面是模型函数的整体顺序:

  • 初始化参数
  • 前向传播
  • 计算代价函数
  • 反向传播
  • 更新参数

纯Python实现鸢尾属植物数据集神经网络模型

纯Python实现鸢尾属植物数据集神经网络模型

步骤8:跑动模型

将隐藏层节点设置为6,最大迭代次数设置为10,000次,并每隔1000次打印出训练的结果:

纯Python实现鸢尾属植物数据集神经网络模型

纯Python实现鸢尾属植物数据集神经网络模型

步骤9:画出分类边界

纯Python实现鸢尾属植物数据集神经网络模型

纯Python实现鸢尾属植物数据集神经网络模型

从图中可以观察到,只有四个点被错误分类。虽然我们可以调整模型来进一步地提高模型训练精度,但该些操作显然会导致过拟合现象的出现。

作者信息

Rohan Joseph,数据科学家

本文由阿里云云栖社区组织翻译。

文章原标题《Neural network on iris data》,译者:海棠,审校:Uncle_LLD。

相关推荐