Python元学习:通用人工智能的实现
上QQ阅读APP看书,第一时间看更新

1.3 通过梯度下降来学习如何通过梯度下降来学习

现在,我们来看一个有趣的元学习算法——通过梯度下降来学习如何通过梯度下降来学习。这个名字是不是有点吓人?实际上,它是最简单的元学习算法之一。我们知道,在元学习中,目标是学习学习的过程。一般如何训练神经网络呢?答案是通过梯度下降来计算损失和最小化损失以训练网络。因此,我们使用梯度下降法来优化模型。如果不使用梯度下降,我们能自动学习这个优化过程吗?

但是应该如何学习呢?我们用递归神经网络(Recurrent Neural Network, RNN)代替传统的梯度下降优化器。这是如何实现的呢?如何用RNN代替梯度下降法?如果你仔细观察梯度下降的行为,就会发现,它基本上是一个从输出层到输入层的更新序列。我们将这些更新存储在一个状态中,这样就可以使用RNN并将更新存储在RNN单元中。

该算法的主要思想是用RNN代替梯度下降法。但问题是RNN如何学习?如何优化RNN?为了优化RNN,我们使用梯度下降法。简而言之,我们正在学习通过RNN来执行梯度下降,而这个RNN是通过梯度下降来优化的。这就是该算法名称的由来。

我们称RNN为优化器,称基网络(base network)为优化对象。假设有一个由参数θ影响的模型f。需要找到这个最优参数θ,以将损失最小化。一般情况下,通过梯度下降法来寻找最优参数,但现在用RNN来寻找最优参数。因此,RNN(优化器)找到了最优参数并将其发送给优化对象(基网络)。优化对象使用这个参数,计算损失,并将损失发送给RNN。基于该损失,RNN通过梯度下降优化自身,并更新模型参数θ

感到困惑吗?请看图1-1:优化对象(基网络)是通过优化器(RNN)优化的。优化器将更新后的参数(即权重)发送给优化对象,优化对象使用这些权重计算损失,并将损失发送给优化器。基于损失,优化器通过梯度下降来改进自身。

图1-1

假设基网络(优化对象)以θ作为参数,而RNN(优化器)以ϕ作为参数。优化器的损失函数是什么?我们知道优化器(RNN)用于减少优化对象(基网络)的损失。因此,优化器的损失是优化对象的平均损失,它可以表示为

L (ϕ)=Ef[f(θ(f, ϕ))]

怎样才能把损失降到最低呢?通过梯度下降找到合适的ϕ来最小化这种损失。RNN接受什么作为输入?它又输出什么呢?优化器,也就是RNN,将优化对象的梯度∇t以及它的上一个状态ht作为输入,并返回输出——可以最小化优化器损失的更新gt。我们用函数m来表示RNN:

(gt, ht+1)=m(∇t, ht, ϕ)

以上方程的参数解释如下:

❑ ∇t是模型f(优化对象)的梯度,即∇t=∇tf(θt);

ht是RNN的隐藏状态;

ϕ是RNN的参数;

❑ 输出gtht 1+ 分别是(提供给优化器的)更新与RNN的下一个状态。

于是,可以使用θt+1=θt+gt来更新模型参数值。

如图1-2所示,优化器m在时间t处,以隐藏状态ht和相对于θt的梯度∇t为输入,计算出gt并将其发送到优化对象,再与θt相加,成为下一步更新的θt+1

图1-2

由此,我们通过梯度下降学会了梯度下降优化。