前篇:https://welts.xyz/2022/04/18/naive_graph/。
代码地址:https://github.com/Kaslanarian/PyAdNet/blob/main/naive_example.py。
在那里,我们构建了标量级的静态计算图。静态计算图的宗旨是先建图再计算。这里的计算包括前向传播和反向传播。深度学习框架,比如TensorFlow1和Theano都采用的是静态图。静态图的一个很反直觉的设定就是,调用计算函数后,用户无法得到计算的结果,因为这种计算函数的目的是建图,而不是计算。所有的计算必须要等到最后的前向传播才能进行,以我们在前面设计的计算图为例:
a = NaiveGraph.Variable(1., name='a')
b = NaiveGraph.Variable(2., name='b')
c = a + b
print(c.get_value()) # 此处c不是3,因为没有前向传播
NaiveGraph.forward()
print(c.get_value()) # 此处c才是3,因为已经前向传播
我们希望c
在计算后能够马上得到结果。这就是动态计算图的思想,PyTorch正是基于动态计算图的深度学习框架。所以我们才能够在PyTorch看到这种操作:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
net = Net()
接着在训练中多次调用net.forward
方法,如果是静态计算图,net.forward
实际上只是建图函数,重复调用会创建多个重复的节点。但在动态计算图中,net.forward
会在构建计算图后马上求值。剩下的问题是,按我们前面说的,动态计算图的求值仍逃不过“建图”这个步骤,而一旦多次求值,计算图中仍然会出现重复的冗余节点。这就涉及到动态计算图所谓“动态”的第二个特性: 反向传播之后销毁计算图。所以对于训练神经网络而言,静态计算图的步骤是:
而动态计算图则是:
所以PyTorch比TensorFlow更灵活,更Pythonic;而TensorFlow难懂,但效率高。这也解释了“学术用PyTorch,企业用TensorFlow”这句话的由来。
下面就是动态计算图的实现方面的探究。
上图是PyTorch建图并求值,然后反向传播的代码与流程。动图的循环播放可以视作前向传播——反向传播的循环。可以发现有两种变量一直存在,一种是用户自定义的输入变量,一种是我们要更新的参数。除此之外,运算过程中产生的节点,我们都会进行删除。
我们考虑PyTorch中Tensor的requires_grad
属性,表明该张量是否需要进行求导,默认是False。我们可以将需要求导的张量视为变量,不需要求导的张量视为常量。由此,PyTorch中的一个很好理解的规定是:一个运算,它的参数(操作数)只要有一个的requires_grad
为True,那么运算结果也是需要求导的张量;如果参数(操作数)全部不需要求导,那么它的运算结果也不需要求导。
在反向传播的过程中,梯度流不需要流到不需要求导的节点。从这点考虑,PyTorch不会将这些节点放到计算图中。于是,在我们日常使用PyTorch时,类似下面的语句对计算图无影响:
import torch
x = torch.zeros((1, 2))
y = torch.ones((1, 2))
z = x + y
即计算图中不会出现x, y, z
,因为它们不需要求导。
对的,x不配,因为我们不需要计算x的梯度。动态计算图只显示需要计算梯度的变量以及它们所依赖的变量。——手把手PyTorch 深度学习 (3): 动态计算图和优化器
所以,PyTorch在反向传播后,对计算图进行销毁,实际上是将满足下列条件的节点删除
requires_grad
标签为True(实际上是节点在计算图中的必要条件);这里的删除包括:
这样它成为孤立节点,只是一个存储值的节点。如果我们对它的值感兴趣,完全可以将它的值用在后面的运算中,甚至以新节点的身份重返计算图。
最后再提一下PyTorch反向传播的一个细节。默认的反向传播后即销毁计算图的机制,在多轮训练神经网络等场景下显得累赘。PyTorch为节点的backward
方法提供retain_graph
参数,当它为True时,反向传播后会保留计算图,而不去销毁。比如下面代码
import torch
x = torch.randn(3, 4, requires_grad=True)
y = x**2
output = y.mean()
output.backward()
output.backward()
在执行第二次反向传播时会报错:
RuntimeError: Trying to backward through the graph a second time,
but the saved intermediate results have already been freed.
Specify retain_graph=True when calling .backward()
or autograd.grad() the first time.
这就是因为计算图已经被销毁了,而一旦指定retain_graph
参数为True:
import torch
x = torch.randn(3, 4, requires_grad=True)
y = x**2
output = y.mean()
output.backward(retain_graph=True)
output.backward()
程序则不会报错。我们在后面会尝试模仿PyTorch的该机制。
在静态计算图的基础上,我们打算实现一个基于标量的动态计算图,语法上更靠近PyTorch。首先是节点的设计,这里我们不再将节点区分为常量,变量和占位符了,而是统一的Node类:
class Node:
def __init__(self, value, requires_grad=False) -> None:
# 生成唯一的id
while True:
new_id = randint(0, 1000)
if new_id not in NaiveGraph.id_list:
break
self.id: int = new_id
self.value = float(value)
self.requires_grad = requires_grad
# grad和grad_fn分别是节点梯度和节点对应的求导函数
# 借鉴PyTorch
self.grad = 0. if self.requires_grad else None
self.grad_fn = None
# 默认是操作符名,该属性为绘图需要
self.name = None
self.next = list()
self.last = list()
# 由于不需要前向传播,所以入度属性被淘汰
self.out_deg, self.out_deg_com = 0, 0
if self.requires_grad:
# 不需要求梯度的节点不出现在动态计算图中
NaiveGraph.add_node(self)
接着是计算函数的改变,我们坚持了模块化的思想,继续沿用一元函数和二元函数框架:
@classmethod
def unary_function_frame(cls, node, operator: str):
if type(node) != cls.Node:
node = cls.Node(node)
# grad_fn_table是一个字符串——函数元组字典,元组中是求值函数和求导函数
fn, grad_fn = cls.grad_fn_table.get(operator)
# 这里fn(node)说明我们直接计算输出,即动态计算图的特征
operator_node = cls.Node(fn(node), node.requires_grad)
operator_node.name = operator
if operator_node.requires_grad:
# 可求导的节点才可有grad_fn成员
operator_node.grad_fn = grad_fn
# 只有可求导的变量间才会用有向边联系
node.build_edge(operator_node)
return operator_node
@classmethod
def binary_function_frame(cls, node1, node2, operator: str):
if type(node1) != cls.Node:
node1 = cls.Node(node1)
if type(node2) != cls.Node:
node2 = cls.Node(node2)
fn, grad_fn = cls.grad_fn_table.get(operator)
# 两个输入只要有一个是变量,输出就是变量
requires_grad = node1.requires_grad or node2.requires_grad
operator_node = cls.Node(
fn(node1, node2), # 直接计算
requires_grad=requires_grad,
)
operator_node.name = operator
if requires_grad:
operator_node.grad_fn = grad_fn
node1.build_edge(operator_node)
node2.build_edge(operator_node)
return operator_node
最后是反向传播步骤,我们抛弃了之前那种基于整个图的反向传播,而是只留出针对特定节点的接口,所以backward
成为了Node类的成员方法:
def backward(self, retain_graph=False):
if self not in NaiveGraph.node_list:
print("AD failed because the node is not in graph")
return
node_queue = []
self.grad = 1.
for node in NaiveGraph.node_list:
if node.requires_grad:
if node.out_deg == 0:
node_queue.append(node)
while len(node_queue) > 0:
node = node_queue.pop()
for last_node in node.last:
last_node.out_deg -= 1
last_node.out_deg_com += 1
if last_node.out_deg == 0 and last_node.requires_grad:
# 加入节点是需要求导的这一条件
for n in last_node.next:
last_node.grad += n.grad * n.grad_fn(n, last_node)
node_queue.insert(0, last_node)
if retain_graph:
# 保留图
for node in NaiveGraph.node_list:
node.out_deg += node.out_deg_com
node.out_deg_com = 0
else:
# 释放计算图:删除所有非叶子节点
new_list = [] # 新的计算图节点集
for node in NaiveGraph.node_list:
if len(node.last) == 0:
# is leaf
new_list.append(node)
else:
# 清除节点信息
node.next.clear()
node.last.clear()
node.in_deg = 0
node.next.clear()
node.out_deg = 0
node.out_deg_com = 0
NaiveGraph.node_list = new_list
类似的,我们想依靠动态计算图进行求导,求Jacobi矩阵和梯度下降法测试,即可微分计算图的实现中的例子,实验结果和之前静态计算图中的相同,说明我们的动态计算图构建是正确的。