假设你需要维护一个列表,这个列表不断有新的元素加入,你需要在任何时候很方便的得到列表中的最大(小)值,因此要求列表始终处于排序完毕状态,。你会怎么做?

一个最简单的方法就是每次插入新的数据时,调用一次sort方法,这样可以保证列表的顺序。在数据量很小的情况下,这种方法可行,但如果数据量很大呢?要知道,Python中列表的sort方法实现并不高明,采用了一种不太有名的自然归并排序,虽然排序开销已经被尽量的压缩了,但仍然不是很理想,复杂度大概是O(nlogn)。

有没有更好的实现方法呢?答案是肯定的!在数据结构的世界里,只有想不到,没有做不到。

另一种解决方案就是heapq,它是Python的一个标准库。heapq实现了一种叫做的数据结构,是一种简洁的二叉树。他能确保父节点总是比子节点小,即满足

#Python code
list[i] < = list[2*i + 1] and list[i] <= list[2*i + 2]

因此,list[0]就是最小的元素。在Python中维护一个堆最好的方式就是使用列表,并用库模块heapq来管理此列表。这个列表无需完成排序,但你却能够确保每次调用heappop从列表中获取元素时,总是当前最小的元素,然后所有节点会自动调整,以确保堆特性仍然有效。每次通过heappush添加元素或通过heappop删除元素时,开销大概是O(logn),在数据量很大时,明显要好于排序的方法。

下面,我将通过一个例子来说明适合堆使用的场景。

假设有一个很长的列表,并且周期性的有新的数据到达,你总是希望能够从队列中获取最重要的元素,而无需不断的重新排序或在整个队列中搜索。这个概念叫做优先级队列,而堆正是最适合实现他的数据结构。注意,heapq模块在每次调用heappop时向你提供最小的元素,因此需要安排你的元素的优先级值,以反应出元素的这个特点。举个例子,假设你每次收到一个数据都付一份钱,而任何时候最重要的元素都是队列中价格最高的那个;另外对于价格相同的元素,先到达的重要一些。下面的代码就是遵循这个要求,使用heapq实现的“优先级队列”类。

class prioq(object):
    def __init__(self):
        self.q = []
        self.i = 0;

    def push(self, item, cost):
        heapq.heappush(self.q, (-cost, self.i, item))
        self.i += 1

    def pop(self):
        return heapq.heappop(self.q)
 

代码中,将价格置为负数,作为原组的第一个元素,并将整个原组压入堆中,这样更高的出价便会产生更小的原组(基于Python的自然比较方式),在价钱之后,我们放置了一个递增索引,这样,当元素拥有相同的价钱时,先到达的元素将会处于更小的原组中。

需要说明的一点是,堆本身并不是一种有序的结构,但可以通过遍历二叉树的方式得到有序的列表。堆排序就是这么做的。

另外,Python在2.3中引入heapq模块,在2.4版本中又被重新实现和进一步优化了。更详细的使用说明,请参考Python标准库文档

 

无意间看了一下 Python heapq 模块的源代码,发现里面的源代码写得很棒,忍不住拿出来和大家分享一下。:)

heapq 的意思是 heap queue,也就是基于 heap 的 priority queue。说到 priority queue,忍不住吐槽几句。我上学的时候学优先级队列的时候就没有碰到像 wikipedia 上那样透彻的解释,priority queue 并不是具体的某一个数据结构,而是对一类数据结构的概括!比如栈就可以看作是后进入的优先级总是大于先进入的,而队列就可以看成是先进入的优先级一定高于后进来的!这还没完!如果我们是用一个无序的数组实现一个priority queue,对它进行出队操作尼玛不就是选择排序么!尼玛用有序数组实现入队操作尼玛不就是插入排序么!尼玛用堆实现(即本文要介绍的)就是堆排序啊!用二叉树实现就是二叉树排序啊!数据结构课学的这些东西基本上都出来了啊!T_T

heapq 的代码也是用 Python 写的,用到了一些其它 Python 模块,如果你对 itertools 不熟悉,在阅读下面的代码之前请先读文档。heapq 模块主要有5个函数:heappush(),把一个元素放入堆中;heappop(),从堆中取出一个元素;heapify(),把一个列表变成一个堆;nlargest() 和 nsmallest() 分别提供列表中最大或最小的N个元素。

先从简单的看起:

PYTHON:

  1. def heappush(heap, item):
  2.     """Push item onto heap, maintaining the heap invariant."""
  3.     heap.append(item)
  4.     _siftdown(heap, 0len(heap)-1)
  5.  
  6. def heappop(heap):
  7.     """Pop the smallest item off the heap, maintaining the heap invariant."""
  8.     lastelt = heap.pop()    # raises appropriate IndexError if heap is empty
  9.     if heap:
  10.         returnitem = heap[0]
  11.         heap[0] = lastelt
  12.         _siftup(heap, 0)
  13.     else:
  14.         returnitem = lastelt
  15.     return returnitem

 

从源代码我们不难看出,这个堆也是用数组实现的,而且是最小堆,即 a[k] < = a[2*k+1] and a[k] <= a[2*k+2]。

heappush() 先把新元素放到末尾,然后把它一直往前移,直到位置合适,调整位置的操作是由 _siftdown() 来完成的。heappop()也不难理解,如果堆中只有一个元素(最小的那个)的话,省去了_siftup()。这两个函数是整个模块的核心:

PYTHON:

  1. def _siftdown(heap, startpos, pos):
  2.     newitem = heap[pos]
  3.     # Follow the path to the root, moving parents down until finding a place
  4.     # newitem fits.
  5.     while pos> startpos:
  6.         parentpos = (pos - 1)>> 1
  7.         parent = heap[parentpos]
  8.         if cmp_lt(newitem, parent):
  9.             heap[pos] = parent
  10.             pos = parentpos
  11.             continue
  12.         break
  13.     heap[pos] = newitem
  14.  
  15. def _siftup(heap, pos):
  16.     endpos = len(heap)
  17.     startpos = pos
  18.     newitem = heap[pos]
  19.     # Bubble up the smaller child until hitting a leaf.
  20.     childpos = 2*pos + 1    # leftmost child position
  21.     while childpos <endpos:
  22.         # Set childpos to index of smaller child.
  23.         rightpos = childpos + 1
  24.         if rightpos <endpos and not cmp_lt(heap[childpos], heap[rightpos]):
  25.             childpos = rightpos
  26.         # Move the smaller child up.
  27.         heap[pos] = heap[childpos]
  28.         pos = childpos
  29.         childpos = 2*pos + 1
  30.     # The leaf at pos is empty now.  Put newitem there, and bubble it up
  31.     # to its final resting place (by sifting its parents down).
  32.     heap[pos] = newitem
  33.     _siftdown(heap, startpos, pos)

 

上面的代码加上注释很容易理解,不是吗?在此基础上实现 heapify() 就很容易了:

PYTHON:

  1. def heapify(x):
  2.     """Transform list into a heap, in-place, in O(len(x)) time."""
  3.     n = len(x)
  4.     # Transform bottom-up.  The largest index there's any point to looking at
  5.     # is the largest with a child index in-range, so must have 2*i + 1 <n,
  6.     # or i <(n-1)/2.  If n is even = 2*j, this is (2*j-1)/2 = j-1/2 so
  7.     # j-1 is the largest, which is n//2 - 1.  If n is odd = 2*j+1, this is
  8.     # (2*j+1-1)/2 = j so j-1 is the largest, and that's again n//2-1.
  9.     for i in reversed(xrange(n//2)):
  10.         _siftup(x, i)

这里用了一个技巧,正如注释中所说,其实只要 siftup 后面一半就可以了,前面的一半自然就是一个heap了。heappushpop() 也可以用它来实现了,而且用不着调用 heappush()+heappop():

PYTHON:

  1. def heappushpop(heap, item):
  2.     """Fast version of a heappush followed by a heappop."""
  3.     if heap and cmp_lt(heap[0], item):
  4.         item, heap[0] = heap[0], item
  5.         _siftup(heap, 0)
  6.     return item

 

第一眼看到这个函数可能觉得它放进去一个再取出来一个有什么意思嘛!仔细想想它很有用,尤其是在后面实现 nlargest() 的时候:

PYTHON:

  1. def nlargest(n, iterable):
  2.     """Find the n largest elements in a dataset.
  3.  
  4.    Equivalent to:  sorted(iterable, reverse=True)[:n]
  5.    """
  6.     if n <0:
  7.         return []
  8.     it = iter(iterable)
  9.     result = list(islice(it, n))
  10.     if not result:
  11.         return result
  12.     heapify(result)
  13.     _heappushpop = heappushpop
  14.     for elem in it:
  15.         _heappushpop(result, elem)
  16.     result.sort(reverse=True)
  17.     return result

 

先从 list 中取出 N 个元素来,然后把这个 list 转化成 heap,把原先的 list 中的所有元素在此 heap 上进行 heappushpop() 操作,最后剩下的一定是最大的!因为你每次 push 进去的不一定是最大的,但你 pop 出来的一定是最小的啊!

但 nsmallest() 的实现就截然不同了:

PYTHON:

  1. def nsmallest(n, iterable):
  2.     """Find the n smallest elements in a dataset.
  3.  
  4.    Equivalent to:  sorted(iterable)[:n]
  5.    """
  6.     if n <0:
  7.         return []
  8.     if hasattr(iterable, '__len__') and n * 10 <len(iterable):
  9.         # For smaller values of n, the bisect method is faster than a minheap.
  10.         # It is also memory efficient, consuming only n elements of space.
  11.         it = iter(iterable)
  12.         result = sorted(islice(it, 0, n))
  13.         if not result:
  14.             return result
  15.         insort = bisect.insort
  16.         pop = result.pop
  17.         los = result[-1]    # los --> Largest of the nsmallest
  18.         for elem in it:
  19.             if cmp_lt(elem, los):
  20.                 insort(result, elem)
  21.                 pop()
  22.                 los = result[-1]
  23.         return result
  24.     # An alternative approach manifests the whole iterable in memory but
  25.     # saves comparisons by heapifying all at once.  Also, saves time
  26.     # over bisect.insort() which has O(n) data movement time for every
  27.     # insertion.  Finding the n smallest of an m length iterable requires
  28.     #    O(m) + O(n log m) comparisons.
  29.     h = list(iterable)
  30.     heapify(h)
  31.     return map(heappop, repeat(h, min(n, len(h))))

 

这里做了一个优化,如果 N 小于 list 长度的1/10的话,就直接用插入排序(因为之前就是有序的,所以很快)。如果 N 比较大的话,就用 heap 来实现了,把整个 list 作成一个堆,然后 heappop() 出来的前 N 个就是最后的结果了!不得不说最后一行代码写得太精炼了!