机器之心专栏
作者:Weiran Huang
本文是一篇介绍 DART+ 的专栏文章,作者们提出一种可微分的神经网络架构查找算法 DARTS+,将早停机制(early stopping)引进到原始的 DARTS[1] 算法中,不只减小了 DARTS 查找的时刻,并且极大地提高了 DARTS 的功能。相关论文《DARTS+: Improved Differentiable Architecture Search with Early Stopping》现已揭露(相关代码稍后也会开源)。
论文地址:https:///publications/DARTS+.pdf
DARTS+ 在原始 DARTS 算法基础上只需简略地参加一条早停机制,就能够在 CIFAR10、CIFAR100 和 ImageNet 上取得 2.32%、14.87% 和 23.7% 的错误率,逾越一系列现有的 DARTS 改善算法,包含 SNAS[2]、P-DARTS[3]、XNAS[4]、PC-DARTS[5] 等。
在模型巨细适当的状况下,DARTS+ 能够到达与谷歌提出的 EfficientNet[6] 相同的功能,可是查找时刻却远远小于 EfficientNet,再叠加上一些常用的 tricks,在 ImageNet 上能够到达 22.5% 的错误率!早停机制的引进,让原本在查找时刻上具有明显优势的根据「可微分」的架构查找办法,在功能上也开端逾越根据「强化学习」或「演化算法」的架构查找办法,极大地增加了「可微分架构查找」的研讨价值和运用规模。
简介
神经网络架构查找(Neural Architecture Search,NAS)在主动机器学习(AutoML)中扮演着重要的人物,近来取得越来越多的重视。用 NAS 查找得到的神经网络架构现已在多种使命上逾越了专家手艺规划的网络架构,包含物体分类、物体检测、引荐体系等。
神经网络架构查找的常见做法是首要规划一个架构查找空间,然后用某种查找战略,从中找出一个最优的网络架构。前期的计划是根据强化学习(RL)或许演化算法(Evolutionary Algorithm)来查找一个有用的网络架构,可是会消耗许多的核算资源(上千个 GPU days),不经济也不环保。后来,一些 One-Shot 的计划相继被提出,其间最具代表性的是 DARTS[1] 算法(Differentiable Architecture Search,可微分的神经网络架构查找)。它把查找空间从离散的放松到接连的,然后能够用梯度下降来一同查找架构和学习权重。详细来说,DARTS 运用了如下的两层优化(Bi-Level Optimization)来查找:
Bi-Level Optimization in DARTS
其间,alpha 是架构的参数,w 是 alpha 对应的模型权重。前者运用 validation data 来进行更新,后者运用 training data 来进行更新。详细细节能够参看 DARTS 的原文。DARTS 成功把查找时刻从上千个 GPU days 削减到了几个 GPU days。
DARTS 算法的问题
DARTS 算法有一个严峻的问题,便是当查找轮数过大时,查找出的架构中会包含许多的 skip-connect,然后功能会变得很差。咱们把这个现象叫做 Collapse of DARTS。
举个比方,让咱们来考虑在 CIFAR100 上用 DARTS 做查找。从下图能够看出,当 search epoch(横轴)比较大的时分,skip-connect 的 alpha 值(绿线)将变得很大。
Alpha Values in The Shallowest Edge
因而,在 DARTS 最终选出的网络架构中,skip-connect 的数量也会跟着 search epoch 变大而越来越多,如下图中的绿线所示。
在一个节点数固定的 cell 中,skip-connect 的数量越多,会导致网络变得越浅。比较于深度网络,浅度网络可学习的参数更少,具有的表达能力更弱。因而,在 DARTS 搜出的网络架构中,skip-connect 的数量太多会导致功能急剧变差。例如,在上图中,当 skip-connect 的数量逾越 2 个的时分,网络的功能(蓝线)开端下降。下图直观展现了跟着 search epoch 变大,网络结构由深变浅的进程。
不同 search epoch 的景象下,在 CIFAR100 上用 DARTS 挑选出的网络结构图
DARTS 发作 Collapse 背面的原因是在两层优化中,alpha 和 w 的更新进程存在先协作(cooperation)后竞赛(competition)的问题。大略来说,在刚开端更新的时分,alpha 和 w 是一同被优化,然后 alpha 和 w 都是越变越好。渐渐地,两者开端变成竞赛联络,因为 w 在竞赛中比 alpha 更有优势(比方,w 的参数量大于 alpha 的参数量,One-Shot 模型在大多数 alpha 下都能收敛,等等),alpha 开端被按捺,因而网络架构呈现了先变好后变差的成果,也便是上上图中蓝线的状况。
详细来说,在查找进程的初始阶段,One-Shot 模型欠拟合到数据集,因而在查找进程刚开端的时分,alpha 和 w(也便是 One-Shot 模型的参数)都会朝着变好的方向更新,这便是协作的阶段。因为整个 One-Shot 模型中,前面的 cell 比后边的 cell 能接触到更洁净的数据,假如咱们答应不同的 cell 能够具有不同的网络结构(打破 DARTS 中 cell 同享网络结构的设定),那么前面的 cell 会比后边的 cell 更快地学到特征。
一旦前面的 cell 现已学到了不错的特征表达,然后边的 cell 学到的特征表达相对较差,那么后边的 cell 接下来会倾向于挑选 skip-connect,来把前面 cell 现已学好的特征表达直接传递到后边。下图是打破 DARTS 中 cell 同享网络结构的设定下,搜出来的网络结构图:能够看到,前面的 cell 大部分都是卷积算子,而靠后的 cell 大部分都是 skip-connect。
打破 cell 同享网络结构的设定下,不同方位的 cell 搜出来的网络结构图
回到 DARTS 的设定,假如咱们强制不同的 cell 同享同一个网络结构,那么 skip-connect 就会从后边的 cell 分散到前面的 cell。当 skip-connect 开端明显变多的时分,协作的阶段就转向了竞赛的阶段:alpha 开端变坏,DARTS 开端 collapse。
值得一提的是,两层优化中的协作和竞赛现象在其他运用中(比方 GAN,meta-learning 等)也有被观察到。以 GAN 为例,一个学好的 discriminator 对练习一个 generator 是至关重要的 [7],这是 generator 和 discriminator 之间的协作;当输入数据(fake 或 real)落在低维流形上一同 discriminator 过参数化的时分,discriminator 很简略把生成的 fake data 从 real data 中区分开来,一同 generator 也会因为发作梯度消失导致无法生成 real data[8],这是 generator 和 discriminator 之间的竞赛。
DARTS+:引进早停机制
为了处理 DARTS 会 collapse 的问题,避免 skip-connect 发作过多,咱们提出一种非常简略并且行之有用的早停机制,改善后的 DARTS 算法称之为 DARTS+ 算法。本文中咱们依然遵从 DARTS 中 cell 同享网络结构的设定,将探究怎么打破 cell 网络结构同享留为 future work。
早停原则:当一个 cell 中呈现两个及两个以上的 skip-connect 的时分,查找进程中止。
DARTS+ 最大的长处便是操作起来非常简略。比较于其他改善 DARTS 的算法,DARTS+ 只需求一点点改动就能够明显地进步功能,一同还能直接削减查找时刻。
上图中的红圈代表各个可学习算子(比方卷积)的 alpha 排序不再改动的时刻点(详细细节请参看原文)。
因为 alpha 值最大的可学习算子对应最终的网络会挑选的算子,当 alpha 排序稳守时,这个算子在最终挑选的网络不会呈现改动,这说明 DARTS 的查找进程现已充沛。从上图中蓝线也能看出,当过了红圈之后,架构的功能开端呈现下降,然后呈现 collapse 问题。因而,咱们能够挑选在可学习算子 alpha 排序不再改动(图中红圈处)的时刻点邻近早停。当早停原则满意时(左图中赤色虚线),根本处于 DARTS 查找充沛处,因而在早停原则处中止查找能够有用避免 DARTS 发作 collapse。
经过上面的剖析,咱们能够给出一个稍杂乱但更为直接的早停原则:
早停原则*:当各个可学习算子(比方卷积)的 alpha 排序满足安稳(比方 10 个 epoch 坚持不变)的时分,查找进程中止。
咱们指出,第一个早停原则更便于操作,而当需求更精准的中止或许引进其他的查找空间的时分,咱们能够用早停原则* 来替代。因为早停机制处理了 DARTS 查找中固有存在的问题,因而,它也能够被用在其它根据 DARTS 的算法中来协助进步进一步功能。
值得一提的是,近来的一些根据 DARTS 改善的算法其实也隐式地运用了早停的主意。
P-DARTS[3] 运用了:1)搜 25 个 epoch 来替代本来的 50 个 epoch,2)在 skip-connects 之后加 dropout,3)手动把 skip-connects 的数目减到 2。
Auto-DeepLab[9] 运用了 20 个 epoch 来训架构参数 alpha,一同发现更多的 epoch(60,80,100)对功能没有优点。
PC-DARTS[5] 运用部分通道连接来下降查找时刻,因而查找收敛需求引进更多的 epoch,然后依然查找 50 个 epoch 便是一个隐式的早停机制。
试验验证
咱们在 CIFAR10[10]、CIFAR100[10]、Tiny-ImageNet-200[11] 和 ImageNet[12] 上分类问题进行验证。在试验中,咱们默许运用第一个早停原则。详细的完成细节,请参看原文。
试验成果如下:
DARTS+ 在 CIFAR10、CIFAR100 和 ImageNet 上取得 2.32%、14.87% 和 23.7% 的错误率,逾越一系列现有的 DARTS 改善算法,包含 SNAS[2]、P-DARTS[3]、XNAS[4]、PC-DARTS[5] 等。在模型巨细适当的状况下,DARTS+ 能够到达与谷歌提出的 EfficientNet-B0[6] 相同的功能,可是查找时刻却远远小于 EfficientNet。假如再叠加 SE 模块,mixup 等,在 ImageNet 上能够到达 22.5% 的错误率。
详细的功能指标如下所示:
CIFAR10 和 CIFAR100 上的试验成果
Tiny-ImageNet-200 上的试验成果
ImageNet 上的试验成果
结语
综上所述,DARTS+ 简略高雅地处理了 DARTS 算法中固有的 collapse 问题,经过引进操作起来非常简略的早停机制,既缩短了查找时刻,又极大地进步了功能。想要进一步提高 DARTS 的功能,一个可行的方向是考虑打破 DARTS 中「不同 cell 同享网络架构」的设置。
本文为机器之心专栏,转载请联络本大众号取得授权。
------------------------------------------------