Python元学习:通用人工智能的实现
上QQ阅读APP看书,第一时间看更新

1.4 少样本学习的优化模型

我们知道,少样本学习基于较少的数据点,那么如何将梯度下降应用到少样本学习中呢?在少样本学习中,梯度下降会由于数据点非常少而突然失效。梯度下降优化需要更多的数据点来达到收敛和损失最小化。因此,在少样本学习中需要一种更好的优化技术。假设有一个由参数θ影响的模型f。我们用一些随机值来初始化参数θ,并尝试使用梯度下降法找到最优值。让我们回忆一下梯度下降的更新方程:

以上方程的参数解释如下:

θt是更新参数;

θt-1是上一步的参数值;

αt是学习率;

是相对于θt-1的损失函数的梯度。

梯度下降的更新方程是不是看起来很熟悉?是的,你猜对了,它类似于长短期记忆网络(LSTM)的细胞状态更新方程,可以写成:

可以将LSTM细胞更新方程与梯度下降完全对应起来,设ft =1,可得:

因此,在少样本学习中,可以使用LSTM而非梯度下降作为优化器。LSTM是元学习器,它将学习用于训练模型的更新规则。因此,我们使用两个网络:一个是基学习器,它学会执行任务;另一个是元学习器,它试图找到最优的参数。这是如何实现的呢?

我们知道,LSTM使用遗忘门(forget gate)来丢弃存储器中不需要的信息,它可以表示为

ft=σ(wf⋅[ht-1,xt]+bf)

这个遗忘门在我们的优化场景中有什么用呢?假设我们处在一个损失很大,梯度接近于零的位置。怎样才能摆脱这种局面呢?在这种情况下,可以收缩模型的参数,并忘记其前一个值的某些部分。我们可以使用遗忘门来实现这一点,它以当前参数值θt-1、当前损失Lt、当前梯度以及前一个遗忘门作为输入。它可以表示为

下面来看看输入门(input gate)。我们知道LSTM中的输入门是用来决定更新什么值的,它可以表示为

it=σ(wi⋅[ht-1, xt]+bi)

在少样本学习中,可以使用这个输入门来调整学习率,从而在防止发散的同时快速学习:

因此,元学习器在多次更新之后得到了itft的最优值。

可是,这是如何运作的呢?

假设有一个由θ影响的基网络M、由ϕ影响的LSTM元学习器R,以及数据集D。我们将数据集分割为训练集Dtrain和测试集Dtest。首先随机初始化元学习器参数ϕ

T次迭代中,随机从Dtrain中抽取数据点,计算损失以及相对于模型参数θ的损失梯度。将这个梯度、损失和元学习器参数ϕ提供给元学习器。元学习器R会返回细胞状态ct,然后在时间t将基网络M的参数θt更新为ct。重复N次,如图1-3所示。

图1-3

因此,经过T次迭代,我们会得到一个最优参数θT。不过如何检查θT的性能并更新元学习器参数呢?使用测试集和参数θT计算测试集的损失。然后,计算相对于元学习器参数ϕ的损失梯度,并更新ϕ,如图1-4所示。

图1-4

迭代n次,并更新元学习器。完整的算法如图1-5所示。

图1-5