动态计算图的实现

以及PyTorch的内部原理浅析

Posted by Welt Xing on April 21, 2022

引言

前篇:https://welts.xyz/2022/04/18/naive_graph/

代码地址:https://github.com/Kaslanarian/PyAdNet/blob/main/naive_example.py

在那里,我们构建了标量级的静态计算图。静态计算图的宗旨是先建图再计算。这里的计算包括前向传播和反向传播。深度学习框架,比如TensorFlow1和Theano都采用的是静态图。静态图的一个很反直觉的设定就是,调用计算函数后,用户无法得到计算的结果,因为这种计算函数的目的是建图,而不是计算。所有的计算必须要等到最后的前向传播才能进行,以我们在前面设计的计算图为例:

a = NaiveGraph.Variable(1., name='a')
b = NaiveGraph.Variable(2., name='b')
c = a + b
print(c.get_value()) # 此处c不是3,因为没有前向传播
NaiveGraph.forward()
print(c.get_value()) # 此处c才是3,因为已经前向传播

我们希望c在计算后能够马上得到结果。这就是动态计算图的思想,PyTorch正是基于动态计算图的深度学习框架。所以我们才能够在PyTorch看到这种操作:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

net = Net()

接着在训练中多次调用net.forward方法,如果是静态计算图,net.forward实际上只是建图函数,重复调用会创建多个重复的节点。但在动态计算图中,net.forward会在构建计算图后马上求值。剩下的问题是,按我们前面说的,动态计算图的求值仍逃不过“建图”这个步骤,而一旦多次求值,计算图中仍然会出现重复的冗余节点。这就涉及到动态计算图所谓“动态”的第二个特性: 反向传播之后销毁计算图。所以对于训练神经网络而言,静态计算图的步骤是:

  • 一次建图+$n$次前向传播+$n$次反向传播;

而动态计算图则是:

  • $n$次前向传播/建图+$n$次反向传播。

所以PyTorch比TensorFlow更灵活,更Pythonic;而TensorFlow难懂,但效率高。这也解释了“学术用PyTorch,企业用TensorFlow”这句话的由来。

进一步分析

下面就是动态计算图的实现方面的探究。

dynamic

上图是PyTorch建图并求值,然后反向传播的代码与流程。动图的循环播放可以视作前向传播——反向传播的循环。可以发现有两种变量一直存在,一种是用户自定义的输入变量,一种是我们要更新的参数。除此之外,运算过程中产生的节点,我们都会进行删除。

我们考虑PyTorch中Tensor的requires_grad属性,表明该张量是否需要进行求导,默认是False。我们可以将需要求导的张量视为变量,不需要求导的张量视为常量。由此,PyTorch中的一个很好理解的规定是:一个运算,它的参数(操作数)只要有一个的requires_grad为True,那么运算结果也是需要求导的张量;如果参数(操作数)全部不需要求导,那么它的运算结果也不需要求导。

在反向传播的过程中,梯度流不需要流到不需要求导的节点。从这点考虑,PyTorch不会将这些节点放到计算图中。于是,在我们日常使用PyTorch时,类似下面的语句对计算图无影响:

import torch

x = torch.zeros((1, 2))
y = torch.ones((1, 2))
z = x + y

即计算图中不会出现x, y, z,因为它们不需要求导。

对的,x不配,因为我们不需要计算x的梯度。动态计算图只显示需要计算梯度的变量以及它们所依赖的变量。——手把手PyTorch 深度学习 (3): 动态计算图和优化器

所以,PyTorch在反向传播后,对计算图进行销毁,实际上是将满足下列条件的节点删除

  1. requires_grad标签为True(实际上是节点在计算图中的必要条件);
  2. 由运算产生,即存在指向该节点的其他节点。

这里的删除包括:

  1. 将该节点踢出计算图节点集合;
  2. 删除该节点的前驱和后继信息。

这样它成为孤立节点,只是一个存储值的节点。如果我们对它的值感兴趣,完全可以将它的值用在后面的运算中,甚至以新节点的身份重返计算图。

最后再提一下PyTorch反向传播的一个细节。默认的反向传播后即销毁计算图的机制,在多轮训练神经网络等场景下显得累赘。PyTorch为节点的backward方法提供retain_graph参数,当它为True时,反向传播后会保留计算图,而不去销毁。比如下面代码

import torch

x = torch.randn(3, 4, requires_grad=True)
y = x**2
output = y.mean()
output.backward()
output.backward()

在执行第二次反向传播时会报错:

RuntimeError: Trying to backward through the graph a second time, 
but the saved intermediate results have already been freed.
Specify retain_graph=True when calling .backward() 
or autograd.grad() the first time.

这就是因为计算图已经被销毁了,而一旦指定retain_graph参数为True:

import torch

x = torch.randn(3, 4, requires_grad=True)
y = x**2
output = y.mean()
output.backward(retain_graph=True)
output.backward()

程序则不会报错。我们在后面会尝试模仿PyTorch的该机制。

动态计算图的实现

在静态计算图的基础上,我们打算实现一个基于标量的动态计算图,语法上更靠近PyTorch。首先是节点的设计,这里我们不再将节点区分为常量,变量和占位符了,而是统一的Node类:

class Node:
    def __init__(self, value, requires_grad=False) -> None:
        # 生成唯一的id
        while True:
            new_id = randint(0, 1000)
            if new_id not in NaiveGraph.id_list:
                break
        self.id: int = new_id
        self.value = float(value)
        self.requires_grad = requires_grad
        # grad和grad_fn分别是节点梯度和节点对应的求导函数
        # 借鉴PyTorch
        self.grad = 0. if self.requires_grad else None
        self.grad_fn = None
        # 默认是操作符名,该属性为绘图需要
        self.name = None
        self.next = list()
        self.last = list()
        # 由于不需要前向传播,所以入度属性被淘汰
        self.out_deg, self.out_deg_com = 0, 0
        if self.requires_grad:
            # 不需要求梯度的节点不出现在动态计算图中
            NaiveGraph.add_node(self)

接着是计算函数的改变,我们坚持了模块化的思想,继续沿用一元函数和二元函数框架:

@classmethod
def unary_function_frame(cls, node, operator: str):
    if type(node) != cls.Node:
        node = cls.Node(node)

    # grad_fn_table是一个字符串——函数元组字典,元组中是求值函数和求导函数
    fn, grad_fn = cls.grad_fn_table.get(operator)
    # 这里fn(node)说明我们直接计算输出,即动态计算图的特征
    operator_node = cls.Node(fn(node), node.requires_grad)
    operator_node.name = operator
    if operator_node.requires_grad:
        # 可求导的节点才可有grad_fn成员
        operator_node.grad_fn = grad_fn
        # 只有可求导的变量间才会用有向边联系
        node.build_edge(operator_node)
    return operator_node

@classmethod
def binary_function_frame(cls, node1, node2, operator: str):
    if type(node1) != cls.Node:
        node1 = cls.Node(node1)
    if type(node2) != cls.Node:
        node2 = cls.Node(node2)

    fn, grad_fn = cls.grad_fn_table.get(operator)
    # 两个输入只要有一个是变量,输出就是变量
    requires_grad = node1.requires_grad or node2.requires_grad
    operator_node = cls.Node(
        fn(node1, node2), # 直接计算
        requires_grad=requires_grad,
    )
    operator_node.name = operator
    if requires_grad:
        operator_node.grad_fn = grad_fn
        node1.build_edge(operator_node)
        node2.build_edge(operator_node)
    return operator_node

最后是反向传播步骤,我们抛弃了之前那种基于整个图的反向传播,而是只留出针对特定节点的接口,所以backward成为了Node类的成员方法:

def backward(self, retain_graph=False):
    if self not in NaiveGraph.node_list:
        print("AD failed because the node is not in graph")
        return

    node_queue = []
    self.grad = 1.

    for node in NaiveGraph.node_list:
        if node.requires_grad:
            if node.out_deg == 0:
                node_queue.append(node)

    while len(node_queue) > 0:
        node = node_queue.pop()
        for last_node in node.last:
            last_node.out_deg -= 1
            last_node.out_deg_com += 1
            if last_node.out_deg == 0 and last_node.requires_grad:
                # 加入节点是需要求导的这一条件
                for n in last_node.next:
                    last_node.grad += n.grad * n.grad_fn(n, last_node)
                node_queue.insert(0, last_node)

    if retain_graph:
        # 保留图
        for node in NaiveGraph.node_list:
            node.out_deg += node.out_deg_com
            node.out_deg_com = 0
    else:
        # 释放计算图:删除所有非叶子节点
        new_list = [] # 新的计算图节点集
        for node in NaiveGraph.node_list:
            if len(node.last) == 0:
                # is leaf
                new_list.append(node)
            else:
                # 清除节点信息
                node.next.clear()
                node.last.clear()
                node.in_deg = 0
            node.next.clear()
            node.out_deg = 0
            node.out_deg_com = 0
        NaiveGraph.node_list = new_list

例子

类似的,我们想依靠动态计算图进行求导,求Jacobi矩阵和梯度下降法测试,即可微分计算图的实现中的例子,实验结果和之前静态计算图中的相同,说明我们的动态计算图构建是正确的。