2.8 PyTorch的计算图和自动求导机制
2.8.1 自动求导机制简介
在前面已经提到,PyTorch会根据计算过程来自动生成动态图,然后可以根据动态图的创建过程进行反向传播,计算得到每个节点的梯度值。为了能够记录张量的梯度,首先需要在创建张量的时候设置一个参数requires_grad=True,意味着这个张量将会加入到计算图中,作为计算图的叶子节点参与计算,通过一系列的计算,最后输出结果张量,也就是根节点。几乎所有的张量创建方式(如2.5.2节中介绍的四种方式)都可以指定requires_grad=True这个参数,一旦指定了这个参数,在后续的计算中得到的中间结果的张量都会被设置成requires_grad=True。对于PyTorch来说,每个张量都有一个grad_fn方法,这个方法包含着创建该张量的运算的导数信息。在反向传播过程中,通过传入后一层的神经网络的梯度,该函数会计算出参与运算的所有张量的梯度。grad_fn本身也携带着计算图的信息,该方法本身有一个next_functions属性,包含连接该张量的其他张量的grad_fn。通过不断反向传播回溯中间张量的计算节点,可以得到所有张量的梯度。一个张量的梯度张量的信息保存在该张量的grad属性中。
除PyTorch张量本身外,PyTorch提供了一个专门用来做自动求导的包,即torch.autograd。它包含有两个重要的函数,即torch.autograd.backward函数和torch.autograd.grad函数。torch.autograd.backward函数通过传入根节点张量,以及初始梯度张量(形状和当前张量的相同),可以计算产生该根节点所有对应的叶子节点的梯度。当张量为标量张量时(Scalar,即只有一个元素的张量),可以不传入初始梯度张量,默认会设置初始梯度张量为1。当计算梯度张量的时候,原先建立起来的计算图会被自动释放,如果需要再次做自动求导,因为计算图已经不存在,就会报错。如果要在反向传播的时候保留计算图,可以设置retain_graph=True。另外,在自动求导的时候默认不会建立反向传播的计算图(因为反向传播也是一个计算过程,可以动态创建计算图),如果需要在反向传播计算的同时建立和梯度张量相关的计算图(在某些情况下,如需要计算高阶导数的情况下,不过这种情况比较少),可以设置create_graph=True。对于一个可求导的张量,也可以直接调用该张量内部的backward方法来进行自动求导。
2.8.2 自动求导机制实例
下面举一个简单的例子来说明自动求导是如何使用的。根据高等数学的知识可知,如果定义一个函数f(x)=x2,则它的导数f′(x)=2x。于是可以创建一个可求导的张量来测试具体的导数,具体如代码2.23所示。
代码2.23 反向传播函数测试代码。
需要注意的一点是,张量绑定的梯度张量在不清空的情况下会逐渐累积。这种特性在某些情况下是有用的,比如,需要一次性求很多迷你批次的累积梯度,但在一般情况下,不需要用到这个特性,所以要注意将张量的梯度清零(模块和优化器都有清零参数张量梯度的函数,会在后面提到)。
2.8.3 梯度函数的使用
在某些情况下,不需要求出当前张量对所有产生该张量的叶子节点的梯度,这时可以使用torch.autograd.grad函数。这个函数的参数是两个张量,第一个张量是计算图的数据结果张量(或张量列表),第二个张量是需要对计算图求导的张量(或张量列表)。最后输出的结果是第一个张量对第二个张量求导的结果(注意最后输出的梯度会累积,和前面介绍的torch.autograd.backward函数的行为一样)。这里需要注意的是,这个函数不会改变叶子节点的grad属性,而不像torch.autograd.backward函数一样会设置叶子节点的grad属性为最后求出梯度张量。同样,torch.autograd.grad函数会在反向传播求导的时候释放计算图,如果需要保留计算图,同样可以设置retain_graph=True。如果需要反向传播的计算图,可以设置create_graph=True。
另外,有时候会碰到一种情况是求导的两个张量之间在计算图上没有关联,在这种情况下函数会报错,如果不需要函数的报错行为,可以设置allow_unused=True这个参数,结果会返回分量全为0的梯度张量(因为两个张量没有关联,所以求导的梯度为0)。
具体的torch.autograd.grad函数的使用方法可以参考代码2.24。
代码2.24 梯度函数的使用方法。
2.8.4 计算图构建的启用和禁用
由于计算图的构建需要消耗内存和计算资源,在一些情况下,计算图并不是必要的,比如神经网络的推导。在这种情况下,可以使用torch.no_grad上下文管理器,在这个上下文管理器的作用域里进行的神经网络计算不会构建任何计算图。
另外,还有一种情况是对于一个张量,我们在反向传播的时候可能不需要让梯度通过这个张量的节点,也就是新建的计算图要和原来的计算图分离。在这种情况下,可以使用张量的detach方法,通过调用这个方法,可以返回一个新的张量,该张量会成为一个新的计算图的叶子节点,新的计算图和老的计算图相互分离,互不影响,具体如代码2.25所示。
代码2.25 控制计算图产生的方法示例。