一步步教你使用Head API在TensorFlow中进行多任务学习!

点击上方关注,All in AI中国

人类学习的一个基本特征是我们可以同时学到很多东西。机器学习中的等效思想被称为多任务学习(MTL),它在实践中变得越来越有用,特别是对于强化学习和自然语言处理。事实上,即使在标准的单任务情况下,也可以设计额外的辅助任务并将其包含在优化过程中以帮助学习。

本文通过展示如何在图像分类基准中解决简单的多任务问题来介绍该领域。重点是TensorFlow(Head API)的一个实验组件,它通过将神经网络的共享组件与特定任务组件解耦,帮助设计MTL的自定义估算器。在这个过程中,我们有机会讨论TensorFlow核心的其他功能,包括tf.data,tf.image和自定义估算器。

本教程的代码作为完全包含的Colab笔记本提供,随时可以测试和实验!

(https://colab.research.google.com/drive/1NMB9lpi7P-GkkELkMU0h-yHtUq531D_Z)

内容一目了然

为了使教程更有趣,我们通过重新实现2014年论文的一部分(通过深度多任务学习进行面部特征点检测)来考虑一个现实的用例。问题很简单:给我们一个面部图像,我们需要定位一系列特征点,即图像上的兴趣点(鼻子、左眼、嘴巴......)和标签,包括人的年龄和性别。每个界标/标签构成图像上的单独任务,并且任务之间明显相关(即,想预测左眼的位置,需要先知道右边的位置)。

一步步教你使用Head API在TensorFlow中进行多任务学习!

来自数据集的示例图像(源)。绿点是地标,每个图像还与一些其他标签相关联,包括年龄和性别。

我们将实现分为三个部分:(i)加载图像(使用tf.data和tf.image); (ii)从论文中实施卷积网络(使用TF的自定义估计量); (iii)使用Head API添加MTL逻辑。

第0步 - 加载数据集

下载数据集(http://mmlab.ie.cuhk.edu.hk/projects/TCDCN/data/MTFL.zip)后,快速检查,发现图像分为三个不同的文件夹(AFLW,lfw_5590和net_7876)。通过不同的文本文件提供训练和测试分割,每行对应一个图像和标签的路径:

一步步教你使用Head API在TensorFlow中进行多任务学习!

来自训练数据集的第一个图像和标签。蓝色数字是图像位置(从左上角开始),红色数字是类别(见下文)。

为简单起见,我们将使用Pandas加载文本文件并调整Unix标准的路径URL,例如:对于训练部分:

一步步教你使用Head API在TensorFlow中进行多任务学习!

在Pandas和scikit-learn中加载数据

由于文本文件不是很大,在这种情况下使用Pandas稍微容易一些,并且提供了一点灵活性。但是,对于较大的文件,更好的选择是直接使用tf.data对象TextLineDataset。

第1步 - 使用tf.data和Dataset对象

现在有了数据,我们可以使用tf.data加载它以使其估算好!在最简单的情况下,我们可以通过Pandas的DataFrame进行切片,就可以获取我们的数据:

一步步教你使用Head API在TensorFlow中进行多任务学习!

从Pandas的DataFrame加载tf.data中的数据

以前,将tf.data与Estimators一起使用的一个主要问题是调试数据集相当复杂,必须通过tf.Session对象。但是,从最新版本开始,即使在使用估算器时,也可以通过启用即时执行来调试数据集。例如,我们可以使用数据集构建8个元素的批次,获取第一批,并在屏幕上输出所有内容:

一步步教你使用Head API在TensorFlow中进行多任务学习!

在即时执行中调试数据集对象

现在是从路径开始加载图像的时候了!通常这不是一件容易的事,因为图像可以有许多不同的扩展、大小,有些可以是黑白,等等。幸运的是,我们可以从TF教程中获取灵感来构建一个简单的函数来封装所有这些逻辑,利用tf.image模块中的工具:

一步步教你使用Head API在TensorFlow中进行多任务学习!

使用tf.image模块解析图像

该函数负责解决大多数解析问题:

  1. 'channels'参数允许在一行中加载彩色和黑白图像;
  2. 我们将所有图像调整为我们想要的格式(40x40,根据原始文件);
  3. 在第8行,我们还标准化了我们的标签,以表示0和1之间的相对位置,而不是绝对的位置(因为我们调整了所有图像的大小,图像可能会有不同的形状)。

我们可以使用其内部的“map”函数将解析函数应用于数据集的每个元素:将它与一些用于测试的额外逻辑放在一起,我们获得最终的加载函数:

一步步教你使用Head API在TensorFlow中进行多任务学习!

从Pandas的DataFrame对象开始的完整数据加载功能

一步步教你使用Head API在TensorFlow中进行多任务学习!

从数据集成功加载单个图像

第2步 - 使用自定义估算器构建卷积网络

下一步,我们想要复制原始论文中的卷积神经网络(CNN):

一步步教你使用Head API在TensorFlow中进行多任务学习!

CNN的逻辑由两部分组成:第一部分是整个图像的通用特征提取器(在所有任务中共享),而对于每个任务,我们有一个单独的、较小的模型作用于图像的最终的特征嵌入。由于以下原因,我们将这些简单模型称为“头部”。通过梯度下降同时训练所有头部。

让我们从特征提取部分开始。为此,我们利用tf.layers对象构建我们的主网络:

一步步教你使用Head API在TensorFlow中进行多任务学习!

使用tf.layers实现特征提取部分

目前,我们将专注于单个头/任务,即估计图像中的鼻子位置。一种方法是使用自定义估算器,允许将我们自己的模型实现与标准Estimator对象的所有功能相结合。

自定义估算器的一个缺点是它们的代码往往非常“冗长”,因为我们需要将估算器的整个逻辑(训练、评估和预测)封装到一个函数中:

一步步教你使用Head API在TensorFlow中进行多任务学习!

我们的第一个自定义估算器的代码

粗略地说,模型函数接收到一个模式参数,我们可以使用它来区分我们期望做什么类型的操作(例如,训练)。反过来,模型函数通过另一个自定义对象EstimatorSpec与主Estimator对象交换所有信息:

一步步教你使用Head API在TensorFlow中进行多任务学习!

自定义估算器的示意图(源)

这不仅使代码难以阅读,而且上面的大多数代码都倾向于“样板”代码,这仅取决于我们面临的具体任务,例如,使用回归问题的均方误差。 Head API是一个实验性功能,旨在简化在这种情况下的编写代码,这是我们的下一个主题。

步骤3a - 使用Head API重写我们的自定义估算器

Head API的想法是,一旦指定了几个关键项,就可以自动生成主要预测组件(我们的模型函数):特征提取部分、损失和我们的优化算法:

一步步教你使用Head API在TensorFlow中进行多任务学习!

从某种意义上说,这与Keras的高级界面类似,但它仍然具有足够的灵活性来定义一系列更有趣的头部,我们很快就会看到。

现在,让我们重写前面的代码,这次使用“regression head”:

一步步教你使用Head API在TensorFlow中进行多任务学习!

与之前相同的模型,使用regression head

从意图和目的来看,这两个模型是等效的,但后者更具可读性并且更不容易出错,因为大多数估计器特定的逻辑现在封装在头部内部。我们可以使用估算器的“训练”界面训练两个模型中的任何一个,并开始得到我们的预测:

一步步教你使用Head API在TensorFlow中进行多任务学习!

我们的单任务模型的预测示例

请不要将Head API(位于tf.contrib中)与tf.contrib.learn.head混淆,后者已弃用。

步骤3b - 多任务学习

我们最终得到了本教程中更有趣的部分:MTL逻辑。请记住,在最简单的情况下,执行MTL相当于在同一个特征提取部分的顶部具有“多个头”,如下所示:

一步步教你使用Head API在TensorFlow中进行多任务学习!

在数学上,我们可以通过最小化任务特定损失的总和来共同优化所有任务。例如,假设我们有回归部分的损失L1(每个地标的均方误差)和分类部分的L2(不同的标记),我们可以通过梯度下降来最小化L = L1 + L2。

在这个(非常冗长的)介绍之后,您可能不会对Head API具有针对这种情况的特定头部(称为多头)感到惊讶。根据我们之前的描述,它允许线性组合源自不同磁头的多个损耗。在这一点上,我将让代码说明一切:

一步步教你使用Head API在TensorFlow中进行多任务学习!

为简单起见,我只考虑两个任务:预测鼻子位置和面部“姿势”(左侧轮廓,左侧,前侧,右侧,右侧轮廓)。我们只需要定义两个单独的头(回归一个,分类一个),并将它们与multi_head对象组合。现在添加更多头只是几行代码的问题!

为简洁起见,我们在此处省略了对输入功能的轻微修改:您可以在Colab笔记本上找到它。

此时的估算器可以使用标准方法进行训练,我们可以同时获得两个预测:

一步步教你使用Head API在TensorFlow中进行多任务学习!

我们的多任务模型的预测:节点位置和姿势(在这种情况下为正面)。

一步步教你使用Head API在TensorFlow中进行多任务学习!

相关推荐