ACMer不得不会的线段树,究竟是种怎样的数据结构?

大家好,欢迎阅读周三算法数据结构专题,今天我们来聊聊一个新的数据结构,叫做线段树。

线段树这个数据结构很多人可能会有点蒙,觉得没有听说过,但是它非常非常有名,尤其是在竞赛圈,可以说是竞赛圈的必备技能。所以如果以后遇到有人看了一点算法导论就在你面前装逼,你就可以问他:请问线段树更新的复杂度是多少?

不过如果你会线段树,你也要小心一点,最好不要在面试的时候随便透露你会这个算法。否则面试官一下子就会知道你是圈里人,然后你会发现你后面的面试问题比之前好像难不少。当然也有可能遇到面试官自己不会,为了防止尴尬强行让你用非线段树的解法来完成,比如我就遇到过……

例题

说了这么多废话,那么线段树究竟是什么呢?线段树的英文是segment tree,其实也算是一个直译。因为这个数据结构和线段没有特别大的关系,我个人感觉翻译成区间树可能更贴近一点。

我们先理解到这里,就是这个数据结构大概和区间有点关系。我们先放一放,先来看一道例题,来实际体会一下,为什么需要线段树这个数据结构,以及它的使用场景究竟是什么。这样我们可以对它有一个更加直观的感受,这道题很简单也很经典,我就是在这道题遇到了面试官不让用线段树的突然袭击。

这道题的题面是这样,给定一个长度为n的数组。这个数组当中有n个整数,然后我们会有两种操作。一种操作叫更新,我们指定更新某一个位置的某个数,第二个操作叫query,给定一个区间,要求这个区间里面元素的最小值。n的范围呢是,操作的数量也是,请问我们应该怎么实现?

线段树概念

当然你可能已经知道要用线段树了,只是不知道线段树是什么以及怎么使用。我们先把这些疑惑放在一边,就单纯简单地用最朴素的方法来思考的话,我们会发现我们每次查询都是的操作。最坏的情况下,我们就是要求整个数组的最小值,那么我们需要依次遍历整个区间来求。那么复杂度再乘上操作的数量,整个程序的复杂度会达到。显然这是一个非常巨大的数字,在算法竞赛场景当中一定会超时。

也就是说简单粗暴是做不出来的,如果你有足够多的做题经验,你就会很自然地想到我们也许需要使用一些数据结构来优化这个查询的复杂度。肯定是不能接受的,即使不能优化到,也至少可以试试。线段树就是这样的数据结构,我们直接来看一张图,我们直接就可以搞明白线段树究竟是干嘛的,以及它的工作原理。

这张图当中的a就是我们存数据的数组,这个数组上面的就是线段树。我们从上往下看,给大家解释一下。最上面一条只有一个数字就是1,它代表的是整个数组的最小值是1。也就是说最上层维护的是整个区间的最小值。然后是第二层,在第二层我们看到了两个数,分别是3和1。很明显,3表示的是左半边区间的最小值,1表示的右半边区间的最小值。

到了第三行我们得到了4个数,同理,再下一层有8个数。很明显这是一颗二叉树,并且二叉树当中的每一个节点维护了一个区间的值。它的叶子节点存储的是长度为1的区间,也就是单个元素。我们把两个兄弟节点维护的区间合并起来就得到了父节点的区间。在这道题当中,由于我们维护的是区间的最小值,所以我们可以得到这么一个式子:

node.min = min(node.left.min, node.right.min)

所以线段树就是利用了二叉树这个层次结构对一个区间进行维护的数据结构。

线段树查询

我们已经了解了线段树的结构了,剩下的就只有两个问题,一个是如何更新一个是如何求解。我发先来看求解,我们要求一个区间的最小值。我们来实际看一下,假设我们想要查询下标是[2, 5]这个区间里的最小值怎么办?

我们对照一下上面的数组a,下标[3, 6]这个区间对应的是[7, 9, 6, 4]这四个值。我们会发现不存在刚好只包含这四个值的区间,那怎么办呢?其实很简单,可以拼凑。我们可以发现我们可以把这个完整的区间转化成两个区间连接在一起的结果。比如下图这样。

这样,我们就把原本比较[7, 9, 6, 4]四个值的一个查询行为转化成了只需要比较4和7两个值大小的比较行为了。这可以替我们节约大量的时间。这和记忆化搜索有一点点像,相当于我们制定一个模式,根据这个模式把区间里的最值存储下来。这样我们查询的时候可以利用这些值来快速求解。

如果我们要求[2, 7]区间内的最小值,那么我们可以转而用这两个区间的值求到。

线段树更新

接下来我们来看下线段树的更新,其实更新和查询的原理是一样的,同样是从根节点出发一层层往下,一直到更新到叶子节点为止。假如说我们把数据当中的4更新成0,那么会达成一种怎样的效果呢?

从结果上来看,我们是把发生变更的叶子节点到树根的这一整个链路都更新了。当然这个更新也不是强制发生的,因为如果我们更新的值比它的原值1要大的话,也是不会更新的。

代码实现

关于线段树的原理我们就差不多讲完了,看起来不太长,这是很正常的。因为线段树的原理其实很简单,就是用一棵二叉树来维护各个长度的区间。我们在查询的时候就是要找到可以拼成我们查询的区间的几个子区间,用这些子区间的值来求到我们要查的区间的值。在我们更新的时候,不需要更新整棵树,只需要更新某一条从根节点到叶子节点的路径就可以了。

原理看起来不难,理解起来也不难,但是要用代码实现出来其实不太容易。因为线段树的所有操作都是基于递归和回溯的,所以想要顺利、深入地理解线段树,对于递归以及回溯的掌握一定要过关。否则线段树你写起来很痛苦,写完了调试会更痛苦。

我们会用面向对象的形式来创建一个线段树,当然也有人喜欢用数组来模拟,这也是可以的,本质上都是一样的。首先我们来创建一个节点类。这个节点类存储的值有3个,一个是它维护的区间的值,在这个题目里维护的是区间最小值。一个是区间的范围, 左右边界。另外一个是左右孩子节点。

由于我们在创建节点的时候还不知道它的左右孩子以及维护的值是什么,所以我们先赋值成None。

class Node:
    def __init__(self, left_side, right_side):
        self.val = None
        self.ls, self.rs = left_side, right_side
        self.left_child, self.right_child = None, None

Node类有了之后,我们就可以利用它来建树了。我们首先来看看建树的方法,也就是常说的build方法。我们创建线段树的时候最重要的就是让它当中的每一个节点能够存储对应区间的最小值。但是呢由于线段树是有层次结构的,我们在创建区间[a, b]的时候,其实可以利用区间[a, m]和区间[m+1, b]两个区间的最小值来获取整个区间的最小值。也就是说我们可以利用当前节点的左右孩子节点完成,我们之前已经说过这点了。

我们来看代码,通过递归可以很方便地完成这一点。

class SegmentTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.vals = arr[:]
        self.root = self.build(0, self.n)

    def build(self, l, r):
        # 传入的l和r表示区间范围,左闭右开
        if r - l < 1:
            return None
        node = Node(l, r)
        # 如果区间长度是1,说明是叶子节点了,直接将val赋值成对应的数值
        if r - l == 1:
            node.val = self.vals[l]
        else:
            # 否则递归调用
            m = (l + r) >> 1
            node.left_child = self.build(l, m)
            node.right_child = self.build(m, r)
            node.val = min(node.left_child.val, node.right_child.val)
        return node

当然这个过程也可以用循环实现,只不过用递归实现更加简单。

如果你能看得到build方法,那么update和query对你来说也都不是问题,其实原理都是一样的,只不过一个是通过递归的形式去更新一个是递归去查询而已。我们先来看update:

    def update(self, k, v):
        self._update(self.root, k, v)

    def _update(self, u, k, v):
        if u is None:
            return
        # 如果k在u这个节点维护的区间里
        if u.ls <= k < u.rs:
            # 更新它的最小值
            u.val = min(u.val, v)
            m = (u.ls + u.rs) >> 1
            # 判断往左还是往右
            if k < m:
                self._update(u.left_child, k, v)
            else:
                self._update(u.right_child, k, v)

最后我们再来看query,query同样是通过递归执行的。由于我们查询的是一个区间,所以我们需要判断我们查询区间和节点维护区间之间的关系。只要抓住了这一点,整个逻辑也是很简单的。

    def query(self, l, r):
        return self._query(self.root, l, r)

    def _query(self, u, l, r):
        # l和r是查询区间
        # 如果查询区间是u节点区间的超集
        if l <= u.ls and r >= u.rs:
            return u.val
        # 如果查询区间只和u节点区间的左半部分有交集
        elif r <= u.left_child.rs:
            return self._query(u.left_child, l, r)
        # 如果查询区间只和u节点右半部分有交集
        elif l >= u.right_child.ls:
            return self._query(u.right_child, l, r)
        # 如果都有交集
        return min(self._query(u.left_child, l, r), self._query(u.right_child, l, r))

最后

到这里,我们关于线段树的基本介绍就算是结束了。注意我说的是基本介绍,因为线段树有很多种用法,今天介绍的只是其中最简单的一种:单点更新区间查询。除此之外还有区间更新单点查询,区间更新区间查询,扫描线等等相对高端一些的用法。由于篇幅所限不能一次讲完,准备放在之后的文章当中分享给大家。

另外一点市面上线段树的题目基本上都是用C++写的,所以如果你想要找一道题试一下的话,可能需要用C++重新写一遍。不过我相信这对于你们来说并不是什么大问题。

今天的文章到这里就结束了,如果喜欢本文的话,请给我一波三连支持吧(关注、转发、点赞)。

原文链接,求个关注

本文使用 mdnice 排版

– END –

{{uploading-image-370795.png(uploading…)}}