Python实现最大堆(大顶堆)

  • 2020 年 1 月 10 日
  • 筆記

最大堆是指最大的元素在堆顶的堆。

Python自带的heapq模块实现的是最小堆,没有提供最大堆的实现。虽然有些文章通过把元素取反再放入堆,出堆时再取反,把问题转换为最小堆问题也能间接实现最大堆,但是这样的实现只适合数值型的元素,不适合自定义类型。

下面给出实现代码:

# -*- coding: UTF-8 -*-
"""Array-backed binary max-heap plus randomized self-checks.

The largest element (as decided by the elements' own comparison operators)
is always the next one returned by ``MaxHeap.pop``.
"""

import random


class MaxHeap(object):
    """Binary max-heap stored in a Python list.

    The children of the node at index ``i`` live at ``2*i + 1`` and
    ``2*i + 2``. Elements are compared with ``<``, ``>`` and ``>=``.
    """

    def __init__(self):
        self._data = []
        # Kept equal to len(self._data) at all times.
        self._count = len(self._data)

    def size(self):
        """Return the number of elements currently in the heap."""
        return self._count

    def isEmpty(self):
        """Return True when the heap holds no elements."""
        return self._count == 0

    def add(self, item):
        """Insert ``item`` into the heap."""
        self._data.append(item)
        self._count += 1
        self._shiftup(self._count - 1)

    def pop(self):
        """Remove and return the largest element, or None if the heap is empty."""
        if self._count == 0:
            return None
        top = self._data[0]
        # Physically remove the last slot. Only decrementing the counter
        # would leave a stale element behind, and the next add() would then
        # append past the logical end and sift the wrong slot.
        last = self._data.pop()
        self._count -= 1
        if self._count > 0:
            self._data[0] = last
            self._shiftDown(0)
        return top

    def _shiftup(self, index):
        """Move self._data[index] up until it is not greater than its parent."""
        parent = (index - 1) >> 1
        while index > 0 and self._data[parent] < self._data[index]:
            self._data[parent], self._data[index] = self._data[index], self._data[parent]
            index = parent
            parent = (index - 1) >> 1

    def _shiftDown(self, index):
        """Move self._data[index] down until it is not smaller than its children."""
        j = (index << 1) + 1
        while j < self._count:
            # At least a left child exists; pick the larger of the two children.
            if j + 1 < self._count and self._data[j + 1] > self._data[j]:
                j += 1
            if self._data[index] >= self._data[j]:
                # Already at least as large as both children: heap order holds.
                break
            self._data[index], self._data[j] = self._data[j], self._data[index]
            index = j
            j = (index << 1) + 1


# Elements are plain numbers.
def testIntValue():
    """Push random distinct ints and check they pop in descending order."""
    for iTimes in range(10):
        iLen = random.randint(1, 300)
        allData = random.sample(range(iLen * 100), iLen)
        print('nlen =', iLen)

        oMaxHeap = MaxHeap()
        print('_data:t   ', allData)
        arrDataSorted = sorted(allData, reverse=True)
        print('dataSorted:', arrDataSorted)
        for i in allData:
            oMaxHeap.add(i)
        heapData = []
        for i in range(iLen):
            iExpected = arrDataSorted[i]
            iActual = oMaxHeap.pop()
            heapData.append(iActual)
            print('{0}, expected: {1}, actual: {2}'.format(iExpected == iActual, iExpected, iActual))
            assert iExpected == iActual, ""
        print('dataSorted:', arrDataSorted)
        print('heapData:  ', heapData)


# Elements are tuples.
def testTupleValue():
    """Push random (key, value) tuples and check they pop in descending key order."""
    for iTimes in range(10):
        iLen = random.randint(1, 300)
        listData = random.sample(range(iLen * 100), iLen)
        # Note: the dict key is what decides the ordering.
        allData = dict(zip(listData, [str(e) for e in listData]))
        print('nlen =', iLen)
        print('allData: ', allData)

        oMaxHeap = MaxHeap()
        arrDataSorted = sorted(allData.items(), key=lambda d: d[0], reverse=True)
        print('dataSorted:', arrDataSorted)
        for (k, v) in allData.items():
            # Tuples compare by their first element first.
            oMaxHeap.add((k, v))
        heapData = []
        for i in range(iLen):
            iExpected = arrDataSorted[i]
            iActual = oMaxHeap.pop()
            heapData.append(iActual)
            print('{0}, expected: {1}, actual: {2}'.format(iExpected == iActual, iExpected, iActual))
            assert iExpected == iActual, ""
        print('dataSorted:', arrDataSorted)
        print('heapData:  ', heapData)


# Elements are instances of a user-defined class.
def testClassValue():
    """Push instances of a custom class and check they pop in descending value order."""

    class Model4Test(object):
        '''
        Custom class used as a heap element.

        MaxHeap compares elements with ``<``, ``>`` and ``>=``. ``__lt__``
        and ``__ge__`` are defined here; ``a > b`` is answered by Python
        through the reflected ``b.__lt__(a)``. ``__cmp__`` is never called
        by Python 3 and is kept only as an ordinary helper method.
        '''

        def __init__(self, sUid, value):
            self._sUid = sUid
            self._value = value

        def getUid(self):
            return self._sUid

        def getValue(self):
            return self._value

        def __lt__(self, other):  # operator <
            return self.getValue() < other.getValue()

        def __ge__(self, other):  # operator >=
            return self.getValue() >= other.getValue()

        def __le__(self, other):  # operator <=
            return self.getValue() <= other.getValue()

        def __cmp__(self, other):
            # Three-way comparison returning -1, 0 or 1. The previous body
            # called the non-existent ``super.cmp`` and raised AttributeError.
            mine, theirs = self.getValue(), other.getValue()
            return (mine > theirs) - (mine < theirs)

        def __str__(self):
            return '({0}, {1})'.format(self._value, self._sUid)

    for iTimes in range(10):
        iLen = random.randint(1, 300)
        listData = random.sample(range(iLen * 100), iLen)
        allData = [Model4Test(str(value), value) for value in listData]
        print('allData:   ', [str(e) for e in allData])
        iLen = len(allData)
        print('nlen =', iLen)

        oMaxHeap = MaxHeap()
        arrDataSorted = sorted(allData, reverse=True)
        print('dataSorted:', [str(e) for e in arrDataSorted])
        for i in allData:
            oMaxHeap.add(i)
        heapData = []
        for i in range(iLen):
            iExpected = arrDataSorted[i]
            iActual = oMaxHeap.pop()
            heapData.append(iActual)
            # No __eq__ is defined, so == is identity here; values are
            # distinct, so the expected object is unambiguous.
            print('{0}, expected: {1}, actual: {2}'.format(iExpected == iActual, iExpected, iActual))
            assert iExpected == iActual, ""
        print('dataSorted:', [str(e) for e in arrDataSorted])
        print('heapData:  ', [str(e) for e in heapData])


if __name__ == '__main__':
    testIntValue()
    testTupleValue()
    testClassValue()