
I'm refreshing some of my data structures. I saw this as the perfect opportunity to get some feedback on my code.

I'm interested in:

Algorithm-wise:

  1. Is my implementation correct? (The tests say so)
  2. Can it be sped up?
  3. Comparing my code to the one in the heapq module, it seems that they do not provide a heap class, but just a set of operations that work on lists. Is this better?
  4. Many implementations I saw iterate over the elements using a while loop in the siftdown method to see if it reaches the end. I instead call siftdown again on the chosen child. Is this approach better or worse?
  5. I've considered adding a parameter to the constructor that specifies the size of the list/array in advance. At creation it would then allocate a list of that size to the heap, only partially used at first. That could counter the cost of repeated list appends, which I believe can tend to be slow. The __last_index pointer would then mark the used part of the array/list. I did not see this in other implementations, so I wasn't sure it would be a good thing.
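A minimal sketch of the idea in point 5, assuming a fixed capacity passed to the constructor (the class name, the overflow policy, and the omission of sifting are all illustrative, not a finished design):

```python
class PresizedHeap:
    """Heap skeleton backed by a pre-allocated list (sketch only)."""

    def __init__(self, capacity):
        self.__array = [None] * capacity  # allocate the full list up front
        self.__last_index = -1            # index of the last used slot

    def push(self, value):
        if self.__last_index + 1 == len(self.__array):
            raise IndexError("heap is full")  # or grow the array here
        self.__last_index += 1
        self.__array[self.__last_index] = value
        # a real heap would now sift value upwards from __last_index

    def __len__(self):
        return self.__last_index + 1
```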

Code-wise:

  1. Is my code clean and readable?
  2. Do my tests suffice (for, say, an interview)?
  3. Is the usage of the subclasses MinHeap and MaxHeap & the comparer method that distinguishes them a good approach to providing both types of heaps?
  4. Is the usage of a classmethod a good idea to provide a createHeap() function that creates a new heap object?
  5. Anything else that can help me improve this code or fancify it? ;-)

Heap implementation

class Heap(object):
    def __init__(self):
        self.__array = []
        self.__last_index = -1

    def push(self, value):
        """ 
            Append item on the back of the heap, 
            sift upwards if heap property is violated.
        """
        self.__array.append(value)
        self.__last_index += 1
        self.__siftup(self.__last_index)

    def pop(self):
        """ 
            Pop root element from the heap (if possible),
            put last element as new root and sift downwards till
            heap property is satisfied.

        """
        if self.__last_index == -1:
            raise IndexError("Can't pop from empty heap")
        root_value = self.__array[0]
        if self.__last_index > 0:  # more than one element in the heap
            self.__array[0] = self.__array[self.__last_index]
            self.__siftdown(0)
        self.__last_index -= 1
        return root_value

    def peek(self):
        """ peek at the root, without removing it """
        if not self.__array:
            return None
        return self.__array[0]

    def replace(self, new_value):
        """ remove root & put NEW element as root & sift down -> no need to sift up """
        if self.__last_index == -1:
            raise IndexError("Can't pop from empty heap")
        root_value = self.__array[0]
        self.__array[0] = new_value
        self.__siftdown(0)
        return root_value

    def heapify(self, input_list):
        """
            each leaf is a trivial subheap, so we may begin to call
            Heapify on each parent of a leaf.  Parents of leaves begin
            at index n/2.  As we go up the tree making subheaps out
            of unordered array elements, we build larger and larger
            heaps, joining them at the i'th element with Heapify,
            until the input list is one big heap.
        """
        n = len(input_list)
        self.__array = input_list
        self.__last_index = n-1
        for index in reversed(range(n//2)):
            self.__siftdown(index)

    @classmethod
    def createHeap(cls, input_list):
        """
            create a heap based on an input list.
        """
        heap = cls()
        heap.heapify(input_list)
        return heap

    def __siftdown(self, index):
        current_value = self.__array[index]
        left_child_index, left_child_value = self.__get_left_child(index)
        right_child_index, right_child_value = self.__get_right_child(index)
        # the following works because if the right_child_index is not None, then the left_child
        # is also not None => property of a complete binary tree, else left will be returned.
        best_child_index, best_child_value = (right_child_index, right_child_value) if right_child_index\
        is not None and self.comparer(right_child_value, left_child_value) else (left_child_index, left_child_value)
        if best_child_index is not None and self.comparer(best_child_value, current_value):
            self.__array[index], self.__array[best_child_index] =\
                best_child_value, current_value
            self.__siftdown(best_child_index)
        return


    def __siftup(self, index):
        current_value = self.__array[index]
        parent_index, parent_value = self.__get_parent(index)
        if index > 0 and self.comparer(current_value, parent_value):
            self.__array[parent_index], self.__array[index] =\
                current_value, parent_value
            self.__siftup(parent_index)
        return

    def comparer(self, value1, value2):
        raise NotImplementedError("Do not use the base class Heap; "
                                  "use MinHeap or MaxHeap instead.")

    def __get_parent(self, index):
        if index == 0:
            return None, None
        parent_index = (index - 1) // 2
        return parent_index, self.__array[parent_index]

    def __get_left_child(self, index):
        left_child_index = 2 * index + 1
        if left_child_index > self.__last_index:
            return None, None
        return left_child_index, self.__array[left_child_index]

    def __get_right_child(self, index):
        right_child_index = 2 * index + 2
        if right_child_index > self.__last_index:
            return None, None
        return right_child_index, self.__array[right_child_index]

    def __repr__(self):
        return str(self.__array[:self.__last_index+1])

    def __eq__(self, other):
        if isinstance(other, Heap):
            return self.__array == other.__array
        if isinstance(other, list):
            return self.__array == other
        return NotImplemented

class MinHeap(Heap):
    def comparer(self, value1, value2):
        return value1 < value2

class MaxHeap(Heap):
    def comparer(self, value1, value2):
        return value1 > value2

Tests

def manualTest():
    """
        Basic test to see step by step changes.
    """
    h = MinHeap()
    h.push(10)
    assert(h == [10])
    h.push(20)
    assert(h == [10, 20])
    h.push(5)
    assert(h == [5, 20, 10])
    h.push(8)
    assert(h == [5, 8, 10, 20])
    h.push(3)
    assert(h == [3, 5, 10, 20, 8])
    h.push(40)
    assert(h == [3, 5, 10, 20, 8, 40])
    h.push(50)
    assert(h == [3, 5, 10, 20, 8, 40, 50])
    h.push(1)
    assert(h == [1, 3, 10, 5, 8, 40, 50, 20])
    assert(h.pop() == 1)
    assert(h.pop() == 3)
    assert(h.pop() == 5)
    assert(h.pop() == 8)
    assert(h.pop() == 10)
    assert(h.pop() == 20)
    assert(h.pop() == 40)
    assert(h.pop() == 50)
    try:
        h.pop()
        assert(False) 
    except IndexError:  # check that IndexError is raised when the heap is empty
        assert(True)
    # check createHeap classmethod.
    assert(MinHeap.createHeap([2,7,3,1,9,44,23]) == [1, 2, 3, 7, 9, 44, 23])
    assert(MaxHeap.createHeap([2,7,3,1,9,44,23]) == [44, 9, 23, 1, 7, 3, 2])


def automaticTest(sample_size):
    """
        Test creating a min & max heap, push random values
        on it and see if the popped values are sorted.
    """
    import random
    random_numbers = random.sample(range(100), sample_size)
    min_heap = MinHeap()
    max_heap = MaxHeap()
    for i in random_numbers:
        min_heap.push(i)
        max_heap.push(i)
    random_numbers.sort()
    for i in random_numbers:
        assert(min_heap.pop() == i)
    random_numbers.sort(reverse=True)
    for i in random_numbers:
        assert(max_heap.pop() == i)

automaticTest(20)
manualTest()

1 Answer

Thanks for sharing your code!

I won't cover all your questions but I will try my best.

(warning, long post incoming)

Is my implementation correct? (The tests say so)

As far as I tried to break it I'd say yes it's correct. But see below for more thorough testing methods.

Can it be sped up?

Spoiler alert: yes

The first thing I did was to change your test file slightly (I called it test_heap.py) to seed the random list generation. I also changed the random.sample call to be more flexible with the sample_size parameter.

It went from

random_numbers = random.sample(range(100), sample_size)

to

random.seed(7777)
random_numbers = random.sample(range(sample_size * 3), sample_size)

So the population from random.sample is always greater than my sample_size. Maybe there is a better way?

I also set the sample size to 50000 to have a decent size for the next step.

The next step was profiling the code with python -m cProfile -s cumtime test_heap.py. If you are not familiar with the profiler, see the docs. I ran the command a few times to get a grasp of the variation in timing; that gave me a baseline for optimization. The original value was:

  7990978 function calls (6561934 primitive calls) in 3.235 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      5/1    0.000    0.000    3.235    3.235 {built-in method builtins.exec}
        1    0.002    0.002    3.235    3.235 test_heap.py:1(<module>)
        1    0.051    0.051    3.233    3.233 test_heap.py:43(automaticTest)
   100009    0.086    0.000    2.759    0.000 heap.py:15(pop)
1400712/100011    1.688    0.000    2.673    0.000 heap.py:70(__siftdown)
  1400712    0.386    0.000    0.386    0.000 heap.py:104(__get_left_child)
  1400712    0.363    0.000    0.363    0.000 heap.py:110(__get_right_child)
   100008    0.064    0.000    0.341    0.000 heap.py:6(push)
228297/100008    0.180    0.000    0.270    0.000 heap.py:85(__siftup)
  1430126    0.135    0.000    0.135    0.000 heap.py:127(comparer)
  1429684    0.128    0.000    0.128    0.000 heap.py:131(comparer)
   228297    0.064    0.000    0.064    0.000 heap.py:98(__get_parent)
        1    0.026    0.026    0.062    0.062 random.py:286(sample)

Now we have a target to beat and some information on what takes time. I did not paste the entire list of function calls; it's pretty long, but you get the idea.
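For reference, the same kind of report can be produced from inside a script with the stdlib cProfile and pstats modules (the work function below is just a stand-in workload, not the heap test itself):

```python
import cProfile
import io
import pstats

def work():
    # stand-in workload; profile any callable the same way
    return sum(i * i for i in range(10_000))

profiler = cProfile.Profile()
profiler.enable()
result = work()
profiler.disable()

buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats("cumulative")  # same ordering as -s cumtime
stats.print_stats(5)            # only the five most expensive entries
print(buffer.getvalue())
```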

A lot of time is spent in _siftdown and a lot less in _siftup, and a few functions are called many times, so let's see if we can fix that.

(I should have started with _siftdown, which was the big fish here, but for some reason I started with _siftup. Forgive me.)

Speeding up _siftup

Before:

def __siftup(self, index):
    current_value = self.__array[index]
    parent_index, parent_value = self.__get_parent(index)
    if index > 0 and self.comparer(current_value, parent_value):
        self.__array[parent_index], self.__array[index] =\
            current_value, parent_value
        self.__siftup(parent_index)
    return

After:

def __siftup(self, index):
    current_value = self.__array[index]
    parent_index = (index - 1) >> 1
    if index > 0:
        parent_value = self.__array[parent_index]
        if self.comparer(current_value, parent_value):
            self.__array[parent_index], self.__array[index] =\
                current_value, parent_value
            self.__siftup(parent_index)
    return

I changed the way parent_index is calculated because I looked at the heapq module source and they use it (see here), but I couldn't see a difference in timing from this change alone.
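The two forms agree for non-negative operands, which covers every valid child index, so the shift is purely a micro-optimization with no change in behavior:

```python
# (index - 1) >> 1 equals (index - 1) // 2 for every child index >= 1;
# both round towards negative infinity, and here the operand is >= 0.
for index in range(1, 10_000):
    assert (index - 1) >> 1 == (index - 1) // 2
```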

Then I removed the call to _get_parent and made the appropriate changes (essentially inlining it, since function calls are not cheap in Python), and the new time is:

7762306 function calls (6333638 primitive calls) in 3.147 seconds

Function calls went down, obviously, but the time only dropped by around 70-80 milliseconds. Not a great victory (a bit less than a 3% speedup), and readability was not improved, so it's up to you whether that is worth it.

Speeding up _siftdown

The first change was to improve readability.

Original version:

def __siftdown(self, index):
    current_value = self.__array[index]
    left_child_index, left_child_value = self.__get_left_child(index)
    right_child_index, right_child_value = self.__get_right_child(index)
    # the following works because if the right_child_index is not None, then the left_child
    # is also not None => property of a complete binary tree, else left will be returned.
    best_child_index, best_child_value = (right_child_index, right_child_value) if right_child_index\
    is not None and self.comparer(right_child_value, left_child_value) else (left_child_index, left_child_value)
    if best_child_index is not None and self.comparer(best_child_value, current_value):
        self.__array[index], self.__array[best_child_index] =\
            best_child_value, current_value
        self.__siftdown(best_child_index)
    return

V2:

def __siftdown(self, index): #v2
    current_value = self.__array[index]
    left_child_index, left_child_value = self.__get_left_child(index)
    right_child_index, right_child_value = self.__get_right_child(index)
    # the following works because if the right_child_index is not None, then the left_child
    # is also not None => property of a complete binary tree, else left will be returned.
    best_child_index, best_child_value = (left_child_index, left_child_value)
    if right_child_index is not None and self.comparer(right_child_value, left_child_value):
        best_child_index, best_child_value = (right_child_index, right_child_value)
    if best_child_index is not None and self.comparer(best_child_value, current_value):
        self.__array[index], self.__array[best_child_index] =\
            best_child_value, current_value
        self.__siftdown(best_child_index)
    return

I transformed the ternary assignment

best_child_index, best_child_value = (right_child_index, right_child_value) if right_child_index\
        is not None and self.comparer(right_child_value, left_child_value) else (left_child_index, left_child_value)

into

best_child_index, best_child_value = (left_child_index, left_child_value)
if right_child_index is not None and self.comparer(right_child_value, left_child_value):
    best_child_index, best_child_value = (right_child_index, right_child_value)

I find it a lot more readable, but that's probably a matter of taste. And to my surprise, when I profiled the code again, the result was:

7762306 function calls (6333638 primitive calls) in 3.079 seconds

(I ran it 10 times and always gained around 80-100 milliseconds.) I don't really understand why; can anybody explain it to me?

V3:

def __siftdown(self, index): #v3
    current_value = self.__array[index]
    
    left_child_index = 2 * index + 1
    if left_child_index > self.__last_index:
        left_child_index, left_child_value = None, None
    else:
        left_child_value = self.__array[left_child_index]
    
    right_child_index = 2 * index + 2
    if right_child_index > self.__last_index:
         right_child_index, right_child_value = None, None
    else:
        right_child_value = self.__array[right_child_index]
    # the following works because if the right_child_index is not None, then the left_child
    # is also not None => property of a complete binary tree, else left will be returned.
    best_child_index, best_child_value = (left_child_index, left_child_value)
    if right_child_index is not None and self.comparer(right_child_value, left_child_value):
        best_child_index, best_child_value = (right_child_index, right_child_value)
    if best_child_index is not None and self.comparer(best_child_value, current_value):
        self.__array[index], self.__array[best_child_index] =\
            best_child_value, current_value
        self.__siftdown(best_child_index)
    return

As in _siftup, I inlined the two calls to the helper functions _get_left_child and _get_right_child, and that paid off!

4960546 function calls (3531878 primitive calls) in 2.206 seconds

That's a 30% speedup from the baseline.

(What follows is a further optimization that I try to explain, but I lost the code I wrote for it; I'll try to write it down again later. It should still give you an idea of the gain.)

Then using the heapq trick of specializing the comparison for max and min (a _siftdown_max and _siftup_max version replacing comparer with >, and doing the same for min) gets us to:

2243576 function calls (809253 primitive calls) in 1.780 seconds

I did not go further with optimizations, but _siftdown is still the big fish, so maybe there is room for more? And maybe pop and push could be reworked a bit, but I don't know how.
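Since I lost that code, here is a sketch of roughly what the min specialization looked like, with the comparer call inlined as < (written over a plain list so it is self-contained; the function name is illustrative, and this is not the exact code I benchmarked):

```python
def siftdown_min(array, index, last_index):
    """Recursive sift-down with the comparison inlined as < (min-heap).

    Assumes both subtrees of index already satisfy the heap property.
    """
    current = array[index]
    left = 2 * index + 1
    if left > last_index:
        return                       # no children: nothing to do
    right = left + 1
    best = left
    if right <= last_index and array[right] < array[left]:
        best = right                 # right child is the smaller one
    if array[best] < current:
        array[index], array[best] = array[best], current
        siftdown_min(array, best, last_index)
```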

Comparing my code to the one in the heapq module, it seems that they do not provide a heapq class, but just provide a set of operations that work on lists? Is this better?

I'd like to know as well!

Many implementations I saw iterate over the elements using a while loop in the siftdown method to see if it reaches the end. I instead call siftdown again on the chosen child. Is this approach better or worse?

Seeing as function calls are expensive, looping instead of recursing might be faster, but I find it better expressed as a recursion.
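For illustration, a loop-based equivalent of __siftdown could look like this (sketched over a plain list with an explicit compare parameter; names are illustrative). A small extra trick: instead of swapping at every level, it only writes the sifted value once, at its final position:

```python
def siftdown_iterative(array, index, last_index, compare):
    """Loop-based sift-down: moves array[index] down until the heap
    property holds, without per-level function-call overhead.
    compare(a, b) is True when a should be above b (e.g. < for a min-heap).
    """
    current = array[index]
    while True:
        left = 2 * index + 1
        if left > last_index:
            break                     # reached a leaf
        best = left
        right = left + 1
        if right <= last_index and compare(array[right], array[left]):
            best = right
        if not compare(array[best], current):
            break                     # neither child beats current: done
        array[index] = array[best]    # move the better child up one level
        index = best
    array[index] = current            # final resting place
```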

Is my code clean and readable?

For the most part, yes! Nice code: you have docstrings for your public methods and you respect PEP 8, so it's all good. Maybe you could add documentation for the private methods as well, especially for hard ones like _siftdown and _siftup?

Just a few things:

  • the ternary assignment I changed in _siftdown I personally find really hard to read.

  • comparer seems like a French name; why not compare? Either I missed something or you mixed languages, and you shouldn't.

Do my test suffice (for say an interview)?

I'd say no. Use a module for unit testing; I personally like pytest.

You prefix the name of your test file with test_ and prefix/suffix your test functions with test_/_test. Then you just run pytest on the command line: it discovers the tests automatically, runs them, and gives you a report. I highly recommend you try it.
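A minimal sketch of that layout (here using the stdlib heapq as a stand-in for your MinHeap, so the file runs on its own; with your code you would import MinHeap instead):

```python
# test_heap.py -- pytest discovers files named test_*.py and runs every
# function whose name starts with test_; plain asserts are enough.
import heapq  # stand-in for the reviewed MinHeap

def test_minheap_pops_sorted():
    h = [2, 7, 3, 1]
    heapq.heapify(h)
    assert [heapq.heappop(h) for _ in range(4)] == [1, 2, 3, 7]

def test_pop_empty_raises():
    try:
        heapq.heappop([])
    except IndexError:
        return
    assert False, "expected IndexError on empty heap"
```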

Another great tool you could use is hypothesis, which does property-based testing. It works well with pytest.

An example for your case:

from hypothesis import given, assume
import hypothesis.strategies as st

@given(st.lists(st.integers()))
def test_minheap(l):
    h = MinHeap.createHeap(l)
    s = sorted(l)
    for i in range(len(s)):
        assert(h.pop() == s[i])
        
@given(st.lists(st.integers()))
def test_maxheap(l):
    h = MaxHeap.createHeap(l)
    s = sorted(l, reverse=True)
    for i in range(len(s)):
        assert(h.pop() == s[i])

It gives pretty much the same kind of testing you did in your automaticTest, but adds a bunch of cool features and is shorter to write.

Raymond Hettinger gave a really cool talk about tools to use when testing on a short time budget; he mentions both pytest and hypothesis. Go check it out :)

Is the usage of the subclasses MinHeap and MaxHeap & the comparer method that distinguishes them a good approach to providing both types of heaps?

I believe it is! But speed-wise, you should instead redeclare siftdown and siftup in the subclasses, replacing each compare(a, b) call with a < b or a > b. One caveat: the double-underscore names are name-mangled per class (__siftdown becomes _Heap__siftdown), so the base class would not pick up a subclass's __siftdown; you would need single-underscore names for the overrides to take effect.

End note

One last remark: on Wikipedia, the article says:

sift-up: move a node up in the tree, as long as needed; used to restore heap condition after insertion. Called "sift" because node moves up the tree until it reaches the correct level, as in a sieve.

sift-down: move a node down in the tree, similar to sift-up; used to restore heap condition after deletion or replacement.

And I think you used the terms in this sense, but the heapq module implementation seems to have the names backwards?

They use siftup in pop and siftdown in push, while Wikipedia tells us to do the inverse. Can somebody explain, please?

(I asked this question on Stack Overflow; hopefully I'll get a response.)
