利用残差神经网络对疟疾细胞图像进行分类

本文描述了构建一个图像分类器的过程和经验教训,该图像分类器能够自动对人类血细胞图像进行分类,从而判断其是否感染了疟疾——因此是一个二元分类。

利用残差神经网络对疟疾细胞图像进行分类

数据集

该项目的灵感来自美国国家医学图书馆的研发部门Lister Hill国家生物医学通信中心的通信工程分部(CEB)的项目。提供的机器学习数据集(https://ceb.nlm.nih.gov/repositories/malaria-datasets/)是平衡的 - 包括总共27,558张细胞图像,其中被寄生(感染)和未感染(清洁)细胞的实例是相等的。对于模型训练/验证,机器学习数据集以80/20的比例拆分。

#from fastai.data_block import *
from fastai.vision import *
import pandas as pd
# Download and unzip the dataset from
# 'https://ceb.nlm.nih.gov/proj/malaria/cell_images.zip'
# to PATH
PATH = 'data/'
DATAPATH = f'{PATH}/cell_images/'
files = get_files(f'{DATAPATH}', extensions='.png', recurse=True)
# Get label from file_path -- folder's name
def get_label(file_path): return 'infected' if '/Parasitized/' in str(file_path) else 'clean'
bs=64 # Batch size
data = ImageDataBunch.from_name_func(f'{DATAPATH}', fnames=files, 
 label_func=get_label, # Parasitized -> infected; Uninfected -> clean
 bs = bs,
 ds_tfms=get_transforms(), 
 size=170 # resize all images
 ).normalize(imagenet_stats)
df = pd.DataFrame(data.y.items)
df['category'] = df[0].replace({0:data.classes[0], 1:data.classes[1]})
data

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

df['category'].value_counts()

clean 11032

infected 11015

Name: category, dtype: int64

# infected cells ratio in the dataset or sample 
data.y.items.sum()/len(data.y.items)

0.4996144600172359

data.show_batch(rows=3, figsize=(7,6))

利用残差神经网络对疟疾细胞图像进行分类

模型1 - ResNet-34

Stage 1:我们将获取预训练的机器学习模型,并在我们的数据上训练它的最后一层。我们将准确度用作指标。

learn = create_cnn(data, models.resnet34, pretrained=True, metrics=accuracy)
learn.fit_one_cycle(8)

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

Stage 1的最佳准确度= 0.964979。

结果分析:

interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
#len(data.valid_ds)==len(losses)==len(idxs)
interp.plot_top_losses(9, figsize=(10,10))

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

以下是混淆矩阵的样子:

doc(interp.plot_top_losses)
interp.plot_confusion_matrix(figsize=(5,5)) #, dpi=60)

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

interp.most_confused(min_val=2)

[('infected', 'clean', 105), ('clean', 'infected', 72)]

Stage 2.现在我们将“Unfreezing”所有层,选择学习率并再次训练整个机器学习模型。

learn.unfreeze() # Enable all layers of NN to learn -- set requires_grad = True
#learn.load('stage-1');
learn.lr_find();
learn.recorder.plot()

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

为了确保最有效的训练,我们选择了一个合适的学习率——大约比上升点低一度。

我们在这里使用的学习率将在从1e-6到1e-5之间。

learn.unfreeze()
learn.fit_one_cycle(4, max_lr=slice(1e-6,1e-5))

利用残差神经网络对疟疾细胞图像进行分类

此时达到的最终精度= 0.966975。然而,它并不是最好的:第3 epoch的准确性更好。尽管如此,ResNet-34模型在Stage 2的准确性略有提高。

interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
#len(data.valid_ds)==len(losses)==len(idxs)
interp.plot_top_losses(9, figsize=(10,10))

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

以下是混淆矩阵的样子:

doc(interp.plot_top_losses)
interp.plot_confusion_matrix(figsize=(5,5))

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

interp.most_confused(min_val=2)

[('infected', 'clean', 101), ('clean', 'infected', 81)]

这意味着:

  • 2764(3.65%)感染细胞中的101个被归类为清洁 - 假阴性 ;
  • 2747(2.95%)清洁细胞中的81个被归类为感染 - 假阳性

模型2 - ResNet-50

Stage 1 只训练最后一层

我们将使用较小的batch size。

data = ImageDataBunch.from_name_func(f'{DATAPATH}', fnames=files, label_func=get_label, 
 bs=bs//2, 
 ds_tfms=get_transforms(), 
 size=170
 ).normalize(imagenet_stats)
learn = create_cnn(data, 
 models.resnet50, 
 pretrained=True, # Leave only the last layer with requires_grad = True
 metrics=accuracy)
learn.lr_find()
learn.recorder.plot()

利用残差神经网络对疟疾细胞图像进行分类

然后选取学习率对机器学习模型进行8个epochs的训练:

learn.fit_one_cycle(8, max_lr=1e-2)

利用残差神经网络对疟疾细胞图像进行分类

ResNet-50在Stage 1 取得的最终结果 - 准确度= 0.967882

一些错误分类的图像:

interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
#len(data.valid_ds)==len(losses)==len(idxs)
interp.plot_top_losses(9, figsize=(10,10))

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

以下是混淆矩阵的样子:

doc(interp.plot_top_losses)
interp.plot_confusion_matrix(figsize=(5,5)) #, dpi=60)

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

interp.most_confused(min_val=2)

[('infected', 'clean', 105), ('clean', 'infected', 72)]

Stage 2

我们现在将Unfreezing ResNet-50模型的所有层,并使用手动选择的学习率再次训练它。

learn.lr_find()
learn.recorder.plot()

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

learn.unfreeze() # Enable all layers of NN to learn -- set requires_grad = True
learn.fit_one_cycle(4, max_lr=slice(2e-5,1e-4))

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

ResNet-50在Stage 2 取得的最佳结果 - 准确度= 0.966975

一些错误分类的图像:

interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
#len(data.valid_ds)==len(losses)==len(idxs)
interp.plot_top_losses(9, figsize=(10,10))

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

以下是混淆矩阵的样子:

doc(interp.plot_top_losses)
interp.plot_confusion_matrix(figsize=(5,5)) #, dpi=60)

利用残差神经网络对疟疾细胞图像进行分类

利用残差神经网络对疟疾细胞图像进行分类

interp.most_confused(min_val=2)

[('infected', 'clean', 103), ('clean', 'infected', 79)]

这意味着:

  • 2708例(3.8%)感染细胞中有103例被归类为清洁 - 假阴性 ;
  • 2803个中有79个(2.82%)清洁细胞被归类为感染 - 假阳性

结论

  • 看下ResNet-50 at Stage 1错误分类的图像,我们可能得出结论,实际上某些图像可能在数据集中被错误地标记 - 明显感染的图像被标记为清洁,反之亦然。
  • 两种机器学习模型的训练准确度相当:

利用残差神经网络对疟疾细胞图像进行分类

  • 但是,Stage 1的ResNet-50更准确。因此,我们可以加载在Stage 1中保存的模型并按原样使用或继续使用不同的超参数(例如另一个学习率)再次训练它以获得更好的结果。
  • 每个epoch的训练时间取决于机器学习模型复杂性 - 层数和训练层数(Stage 1的最后一层与Stage 2的所有层)。此外,它可能与批量大小相关 - 较小的batch size(32 vs. 64)用于ResNet-50。

利用残差神经网络对疟疾细胞图像进行分类

相关推荐