Python实现最大堆(大顶堆)
- 2020 年 1 月 10 日
- 筆記
最大堆是指最大的元素在堆顶的堆。
Python自带的heapq模块实现的是最小堆,没有提供最大堆的实现。有些文章通过把元素取反后再放入堆、出堆时再取反一次,把问题转换为最小堆问题,从而间接实现最大堆;但是这种做法只适用于数值型的元素,不适用于自定义类型。
下面给出实现代码:
# -*- coding: UTF-8 -*-
import random


class MaxHeap(object):
    """Binary max-heap backed by a Python list.

    The largest element (according to ``<``) is always at index 0 of the
    backing list. Elements only need to support the ``<`` operator, which is
    the same requirement ``sorted()`` and ``heapq`` place on their items, so
    numbers, tuples and custom classes defining ``__lt__`` all work.
    """

    def __init__(self):
        # Backing array; children of index i live at 2*i + 1 and 2*i + 2.
        self._data = []

    def size(self):
        """Return the number of elements currently stored in the heap."""
        return len(self._data)

    def isEmpty(self):
        """Return True when the heap holds no elements."""
        return not self._data

    def add(self, item):
        """Insert ``item`` into the heap."""
        self._data.append(item)
        self._shiftup(len(self._data) - 1)

    def pop(self):
        """Remove and return the largest element.

        Returns ``None`` when the heap is empty.
        """
        if not self._data:
            return None
        # Physically remove the last slot so the list never holds stale
        # entries; otherwise a later add() would append past the logical end
        # of the heap and sift the wrong element.
        last = self._data.pop()
        if not self._data:
            return last
        top = self._data[0]
        self._data[0] = last
        self._shiftDown(0)
        return top

    def _shiftup(self, index):
        """Move ``self._data[index]`` up until its parent is not smaller."""
        data = self._data
        while index > 0:
            parent = (index - 1) >> 1
            if not data[parent] < data[index]:
                break
            data[parent], data[index] = data[index], data[parent]
            index = parent

    def _shiftDown(self, index):
        """Move ``self._data[index]`` down until no child is larger."""
        data = self._data
        size = len(data)
        child = (index << 1) + 1
        while child < size:  # there is at least a left child
            right = child + 1
            if right < size and data[child] < data[right]:
                # The right child exists and is the larger of the two.
                child = right
            if not data[index] < data[child]:
                # Already at least as large as both children.
                break
            data[index], data[child] = data[child], data[index]
            index = child
            child = (index << 1) + 1


def testIntValue():
    """Check that integers pop from the heap in descending order."""
    for iTimes in range(10):
        iLen = random.randint(1, 300)
        allData = random.sample(range(iLen * 100), iLen)
        print('nlen =', iLen)
        oMaxHeap = MaxHeap()
        print('_data:t ', allData)
        arrDataSorted = sorted(allData, reverse=True)
        print('dataSorted:', arrDataSorted)
        for i in allData:
            oMaxHeap.add(i)
        heapData = []
        for iExpected in arrDataSorted:
            iActual = oMaxHeap.pop()
            heapData.append(iActual)
            print('{0}, expected: {1}, actual: {2}'.format(
                iExpected == iActual, iExpected, iActual))
            assert iExpected == iActual, ""
        print('dataSorted:', arrDataSorted)
        print('heapData: ', heapData)


def testTupleValue():
    """Check that (key, value) tuples pop in descending key order."""
    for iTimes in range(10):
        iLen = random.randint(1, 300)
        listData = random.sample(range(iLen * 100), iLen)
        # Note: the key (first tuple element) decides the ordering.
        allData = dict(zip(listData, [str(e) for e in listData]))
        print('nlen =', iLen)
        print('allData: ', allData)
        oMaxHeap = MaxHeap()
        arrDataSorted = sorted(allData.items(), key=lambda d: d[0], reverse=True)
        print('dataSorted:', arrDataSorted)
        for (k, v) in allData.items():
            # Tuples compare element by element, so the key is compared first.
            oMaxHeap.add((k, v))
        heapData = []
        for iExpected in arrDataSorted:
            iActual = oMaxHeap.pop()
            heapData.append(iActual)
            print('{0}, expected: {1}, actual: {2}'.format(
                iExpected == iActual, iExpected, iActual))
            assert iExpected == iActual, ""
        print('dataSorted:', arrDataSorted)
        print('heapData: ', heapData)


def testClassValue():
    """Check that instances of a custom class pop in descending order."""

    class Model4Test(object):
        '''
        Custom class used as a heap element.

        Ordering is defined through the rich comparison methods below and is
        based on the wrapped value. No __eq__ is defined, so == compares
        object identity; the test relies on every value being unique.
        '''

        def __init__(self, sUid, value):
            self._sUid = sUid
            self._value = value

        def getUid(self):
            return self._sUid

        def getValue(self):
            return self._value

        def __lt__(self, other):  # operator <
            return self.getValue() < other.getValue()

        def __ge__(self, other):  # operator >=
            return self.getValue() >= other.getValue()

        def __le__(self, other):  # operator <=
            return self.getValue() <= other.getValue()

        def __str__(self):
            return '({0}, {1})'.format(self._value, self._sUid)

    for iTimes in range(10):
        iLen = random.randint(1, 300)
        listData = random.sample(range(iLen * 100), iLen)
        allData = [Model4Test(str(value), value) for value in listData]
        print('allData: ', [str(e) for e in allData])
        iLen = len(allData)
        print('nlen =', iLen)
        oMaxHeap = MaxHeap()
        arrDataSorted = sorted(allData, reverse=True)
        print('dataSorted:', [str(e) for e in arrDataSorted])
        for i in allData:
            oMaxHeap.add(i)
        heapData = []
        for iExpected in arrDataSorted:
            iActual = oMaxHeap.pop()
            heapData.append(iActual)
            print('{0}, expected: {1}, actual: {2}'.format(
                iExpected == iActual, iExpected, iActual))
            assert iExpected == iActual, ""
        print('dataSorted:', [str(e) for e in arrDataSorted])
        print('heapData: ', [str(e) for e in heapData])


if __name__ == '__main__':
    testIntValue()
    testTupleValue()
    testClassValue()