动态规划入门——动态规划与数据结构的结合,在树上做DP

本文由TechFlow原创,本博文仅作为知识点学习,不会用于任何商业用途!


今天我们来看一个有趣的问题,通过这个有趣的问题,我们来了解一下在树形结构当中做动态规划的方法。

这个问题题意很简单,给定一棵树,并不一定是二叉树,树上的树枝带有权重,可以看成是长度。要求树上最长的链路的长度是多少?

比如我们随手画一棵树,可能丑了点,勿怪:

如果让我们用肉眼来看,稍微尝试一下就能找到答案,最长的路径应该是下图当中红色的这条:

但是如果让我们用算法来算,应该怎么办呢?

这道题其实有一个非常巧妙的办法,我们先不讲,先来看看动态规划怎么解决这个问题。

树形DP

动态规划并不只是可以在数组当中运行,实际上只要满足动态规划的状态转移的条件和无后效性就可以使用动态规划,无论在什么数据结构当中。树上也是一样的,明白了这点之后,就只剩下了两个问题,第一个是状态是什么,第二个问题是状态之间怎么转移?

在之前的背包问题当中,状态就是背包当前用的体积,转移呢就是我们新拿一个物品的决策。但是这一次我们要在树上进行动态规划,相对来说状态和对应的转移会隐蔽一些。没有关系,我会从头开始整理思路,一点一点将推导和思考的过程讲解清楚。

首先,我们都知道,状态之间转移其实本质上是一个由局部计算整体的过程。我们通过相对容易的子状态进行转移,得到整体的结果。这个是动态规划的精髓,某种程度上来说它和分治法也比较接近,都存在大问题和小问题之间逻辑上的关系。所以当我们面临一个大问题一筹莫展的时候,可以借鉴一下分治法,思考一下从小问题入手。

所以,我们从小到大,由微观到宏观,来看看最简单的情况:

这种情况很明显,链路只有一条,所以长度自然是5 + 6 = 11,这显然也是最长的长度。这种情况都没有问题,下面我们来把情况稍微再变得复杂一些,我们在树上多加入一层:

这张图稍微复杂了一些,但是路径也不难找到,应该是E-B-F-H。路径的总长度为12:

但是如果我们变更一下路径长度呢,比如我们把FG和FH的路径加长,会得到什么结果呢?

显然这种情况下答案就变了,FGH是最长的。

举这个例子只为了说明一个很简单的问题,即对于一棵树而言它上面的最长路径并不一定经过根节点。比如刚才的例子当中,如果路径必须要经过B的话,最长只能构造出4+2+16=22的长度,但是如果可以不用经过B的话,可以得到最长的长度是31。

得出这个结论看似好像没有用,但其实对于我们理清思路很有帮助。既然我们不能保证最长路径一定会经过树根,所以我们就不能直接转移答案。那我们应该怎么办呢?

回答这个问题光想是不够的,依然需要我们来观察问题和深入思考。

转移过程

我们再观察一下下面这两张图:

有没有发现什么规律?

由于我们的数据结构就是树形的,所以这个最长路径不管它连通的哪两个节点,一定可以保证,它会经过某一棵子树的根节点。不要小看这个不起眼的结论,实际上它非常重要。有了这个结论之后,我们将整条路径在根节点处切开。

切开之后我们得到了两条通往叶子节点的链路,问题来了,根节点通往叶子节点的链路有很多条,为什么是这两条呢?

很简单,因为这两条链路最长。所以这样加起来之后就可以保证得到的链路最长。这两条链路都是从叶子节点通往A的,所以我们得到的最长链路就是以A为根节点的子树的最长路径。

我们前面的分析说了,最长路径是不能转移的,但是到叶子的最长距离是可以转移的。我们举个例子:

F到叶子的最长距离显然就是5和6中较大的那个,B稍微复杂一些,D和E都是叶子节点,这个容易理解。它还有一个子节点F,对于F来说它并不是叶子节点,但是我们前面算到了F到叶子节点的最长距离是6,所以B通过F到叶子节点的最长距离就是2 + 6 = 8。这样我们就得到了状态转移方程,不过我们转移的不是要求的答案而是从当前节点到叶子节点的最长距离和次长距离

因为只有最长距离是不够的,因为我们要将根节点的最长距离加上次长距离得到经过根节点的最长路径,由于我们之前说过,所有的路径必然经过某棵子树的根节点。这个想明白了是废话,但是这个条件的确很重要。既然所有的链路都至少经过某一个子树的根节点,那么我们算出所有子树经过根节点的最长路径,其中最长的那个不就是答案么?

下面我们演示一下这个过程:

上图当中用粉色笔标出的就是转移的过程,对于叶子节点来说最长距离和次长距离都是0,主要的转移过程发生在中间节点上。

转移的过程也很容易想通,对于中间节点i,我们遍历它所有的子节点j,然后维护最大值和次大值,我们写下状态转移方程:

状态转移想明白了,剩下的就是编码的问题了。可能在树上尤其是递归的时候做状态转移有些违反我们的直觉,但实际上并不难,我们写出代码来看下,我们首先来看建树的这个部分。为了简化操作,我们可以把树上所有的节点序号看成是int,对于每一个节点,都会有一个数组存储所有与这个节点连接的边,包括父亲节点。

由于我们只关注树上的链路的长度,并不关心树的结构,树建好了之后,不管以哪一个点为整体的树根结果都是一样的。所以我们随便找一个节点作为整棵树的根节点进行递归即可。强调一下,这个是一个很重要的性质,因为本质上来说,树是一个无向无环全连通图。所以不管以哪个节点为根节点都可以连通整棵子树。

我们创建一个类来存储节点的信息,包括id和两个最长以及次长的长度。我们来看下代码,应该比你们想的要简单得多。

class Node(object):
    def __init__(self, id):
        self.id = id
        # 以当前节点为根节点的子树到叶子节点的最长链路
        self.max1 = 0
        # 到叶子节点的次长链路
        self.max2 = 0
        # 与当前节点相连的边
        self.edges = []

    # 添加新边
    def add_edge(self, v, l):
        self.edges.append((v, l))


# 创建数组,存储所有的节点
nodes = [Node(id) for id in range(12)]

edges = [(0, 1, 3), (0, 2, 1), (1, 3, 1), (1, 4, 4), (1, 5, 2), (5, 6, 5), (5, 7, 6), (2, 8, 7), (7, 9, 2), (7, 10, 8)]

# 创建边
for edge in edges:
    u, v, l = edge
    nodes[u].add_edge(v, l)
    nodes[v].add_edge(u, l)

由于我们只是为了传达思路,所以省去了许多面向对象的代码,但是对于我们理解题目思路来说应该是够了。

下面,我们来看树上做动态规划的代码:

def dfs(u, f, ans):
    nodeu = nodes[u]
    # 遍历节点u所有的边
    for edge in nodes[u].edges:
        v, l = edge
        # 注意,这其中包括了父节点的边
        # 所以我们要判断v是不是父节点的id
        if v == f:
            continue
        # 递归,更新答案
        ans = max(ans, dfs(v, u, ans))
        nodev = nodes[v]
        # 转移最大值和次大值
        if nodev.max1 + l > nodeu.max1:
            nodeu.max1 = nodev.max1 + l
        elif nodev.max1 + l > nodeu.max2:
            nodeu.max2 = nodev.max1 + l
    # 返回当前最优解
    return max(ans, nodeu.max1 + nodeu.max2)

看起来很复杂的树形DP,其实代码也就只有十来行,是不是简单得有些出人意料呢?

但是还是老生常谈的话题,这十几行代码看起来简单,但是其中的细节还是有一些的,尤其是涉及到了递归操作。对于递归不是特别熟悉的同学可能会有些吃力,建议可以根据之前的图手动在纸上验算一下,相信会有更深刻的认识。

另一种做法

文章还没完,我们还有一个小彩蛋。其实这道题还有另外一种做法,这种做法非常机智,也一样介绍给大家。

之前我们说了,由于树记录的是节点的连通状态,所以不管以哪个节点为根节点,都不会影响整棵树当中路径的长度以及结构。既然如此,如果我们富有想象力的话,我们把一棵树压扁,是不是可以看成是一串连在一起的绳子或者木棍?

我们来看下图:

我们把C点向B点靠近,并不会影响树的结构,毕竟这是一个抽象出来的架构,我们并不关注树上树枝之间的夹角。我们可以想象成我们拎起了A点,其他的几点由于重力的作用下垂,最后就会被拉成一条直线。

比如上图当中,我们拎起了A点,BCD都垂下。这个时候位于最下方的点是D点。那么我们再拎起D点,最下方的点就成了C点,那么DC之间的距离就是树上的最长链路:

我们把整个过程梳理一下,首先我们随便选了一个点作为树根,然后找出了距离它最远的点。第二次,我们选择这个最远的点作为树根,再次找到最远的点。这两个最远点之间的距离就是答案。

这种做法非常直观,但是我也想不到可以严谨证明的方法,有思路的小伙伴可以在后台给我留言。如果有些想不通的小伙伴可以自己试着用几根绳子连在一起,然后拎起来做个实验。看看这样拎两次得到的两个点,是不是树上距离最远的两个点。

最后,我们来看下代码:

def dfs(u, f, dis, max_dis, nd):
    nodeu = nodes[u]
    for edge in nodes[u].edges:
        v, l = edge
        if v == f:
            continue
        nodev = nodes[v]
        # 更新最大距离,以及最大距离的点
        if dis + l > max_dis:
            max_dis, nd = dis+l, nodev
        # 递归
        _max, _nd = dfs(v, u, dis+l, max_dis, nd)
        # 如果递归得到的距离更大,则更新
        if _max > max_dis:
            max_dis, nd = _max, _nd
    # 返回
    return max_dis, nd

# 第一次递归,获取距离最大的节点
_, nd = dfs(0, -1, 0, 0, None)
# 第二次递归,获取最大距离
dis, _ = dfs(nd.id, -1, 0, 0, None)
print(dis)

到这里,这道有趣的题目就算是讲解完了,不知道文中的两种做法大家都学会了吗?第一次看可能会觉得有些蒙,问题很多这是正常的,但核心的原理并不难,画出图来好好演算一下,一定可以得到正确的结果。