\$\begingroup\$

4Sum problem

Given an array S of n integers, are there elements a, b, c, and d in S such that a + b + c + d = target? Find all unique quadruplets in the array that give the sum of target.

Note: The solution set must not contain duplicate quadruplets.

The idea is to put every pair sum a in a hashmap along with the corresponding indexes and, once done, check whether target - a is also present in the hashmap. If both a and target - a are present then, since the question asks for unique quadruplets, we can filter out duplicates using the indexes.

class Solution(object):
    def fourSum(self, arr, target):
        seen = {}
        for i in range(len(arr)-1):
            for j in range(i+1, len(arr)):
                if arr[i]+arr[j] in seen: 
                    seen[arr[i]+arr[j]].add((i,j))
                else: 
                    seen[arr[i]+arr[j]] = {(i,j)}
        result = []
        for key in seen:
            if -key + target in seen:
                for (i,j) in seen[key]:
                    for (p,q) in seen[-key + target]:
                        sorted_index = sorted([arr[i], arr[j], arr[p], arr[q]])
                        if i not in (p, q) and j not in (p, q) and sorted_index not in result:
                            result.append(sorted_index)
        return result
\$\endgroup\$

2 Answers

\$\begingroup\$
  • Use enumerate rather than range(len(...)) + __getitem__. It is both faster and more readable.
  • To limit items of the second iteration to be "after the current item" you can use itertools.combinations.
  • To avoid the need to check for the special case of "is the item already in the dictionary?", use a collections.defaultdict.
  • You could use a set rather than a list to store the final results, saving yourself the need to check for duplicates.
  • -key + target is better written as target - key

import itertools
from collections import defaultdict


def four_sum(array, target):
    seen = defaultdict(set)
    for (i, first), (j, second) in itertools.combinations(enumerate(array), 2):
        seen[first + second].add((i, j))

    result = set()
    for key, first_indices in seen.items():
        second_indices = seen.get(target - key, set())
        for p, q in second_indices:
            for i, j in first_indices:
                # Not reusing the same number twice
                if not ({i, j} & {p, q}):
                    indices = tuple(sorted(array[x] for x in (i, j, p, q)))
                    result.add(indices)
    return result
\$\endgroup\$
  • \$\begingroup\$ Yours is actually slower compared to OP's on leetcode; I must agree it is more readable, though. Yours: 335ms, OP: 239ms. It must return a list, so I've changed it a bit, but I still didn't really expect that. :) \$\endgroup\$
    – Ludisposed
    Commented Nov 27, 2017 at 9:51
  • \$\begingroup\$ Note: The solution set must not contain duplicate quadruplets. Yeah, online judges and their requirements matching their specs… \$\endgroup\$ Commented Nov 27, 2017 at 9:53
\$\begingroup\$

Implementation

  • Why not build the result with the condition i < j < p < q?
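That ordering condition could look something like the sketch below (a hypothetical function, four_sum_ordered, not the OP's exact code): pairs in the hashmap already satisfy i < j and p < q, so requiring j < p guarantees i < j < p < q, which both rules out index reuse and makes each index quadruplet appear exactly once.

```python
from itertools import combinations


def four_sum_ordered(arr, target):
    # Pair sums keyed by value; each stored index pair (i, j) has i < j
    # because itertools.combinations emits indexes in increasing order.
    seen = {}
    for (i, a), (j, b) in combinations(enumerate(arr), 2):
        seen.setdefault(a + b, []).append((i, j))

    result = set()
    for key, pairs in seen.items():
        for i, j in pairs:
            for p, q in seen.get(target - key, []):
                # j < p gives i < j < p < q: no index reused, and each
                # index quadruplet is produced by exactly one split.
                if j < p:
                    result.add(tuple(sorted((arr[i], arr[j], arr[p], arr[q]))))
    return [list(quad) for quad in result]
```

Any four indexes summing to target can be sorted and split into their first two and last two elements, and that split is exactly one (key, target - key) pairing in the map, so no quadruplet is missed.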

Algorithm

  • The code builds the hash map from combinations of all indexes of nums. Combinations of all unique values of nums (or of the indexes of unique values) are a better choice. Case: fourSum([0 for x in range(n)], 0)
  • The code puts into the hash map sums of integers from nums that can never contribute to the result. Case: fourSum([x for x in range(1, n, 1)], 0)
  • The code checks whether target - key also exists for a key from the hash map only in the final loop; this check could be done earlier. Case: fourSum([x for x in range(0, n*10, 10)], n*5+1)
  • You can split the hash map into two parts: one for the a,b pairs and one for the c,d pairs. This doesn't change the complexity of building the hash map, but it cuts the final loop to 1/2 * 1/2 = 1/4 of the work.
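The first bullet's "unique values" idea can be sketched as a small helper (pair_sums_by_value is a hypothetical name): key the pair sums on values rather than indexes, using a Counter so a value may pair with itself only when it occurs at least twice. On fourSum([0 for x in range(n)], 0) this produces a single pair instead of roughly n^2/2 index pairs.

```python
from collections import Counter
from itertools import combinations


def pair_sums_by_value(nums):
    # Pair sums keyed on values instead of indexes, so an input like
    # [0] * 10000 yields one pair rather than ~50 million.
    counter = Counter(nums)
    uniques = sorted(counter)
    pairs = {}
    # Pairs of two distinct values.
    for x, y in combinations(uniques, 2):
        pairs.setdefault(x + y, []).append((x, y))
    # A value paired with itself, valid only if it occurs at least twice.
    for x in uniques:
        if counter[x] > 1:
            pairs.setdefault(x + x, []).append((x, x))
    return pairs
```

The counts kept in the Counter are still needed later, to check that a candidate quadruplet does not use a value more often than it appears in nums.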

Speedup

  • best: improve the algorithm in big O terms, e.g. reduce O(n^2) memory to O(n)
  • sometimes good: improve the algorithm's constants, e.g. split the hash map into one for the first pair and one for the second pair
  • bad: dirty, low-level speed-up of the language constants, e.g. replacing itertools.combinations with explicit loops. This is an anti-pattern: the code becomes less understandable, maintainable and changeable, and paradoxically often slower. Slower because bottlenecks are usually caused by several algorithms cascading, e.g. O(n^3) * O(n^3). With clean code it is easier to reduce the problem to O(n^5) or less; with dirty code we usually end up with O(n^6) with a small constant.
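For the "reduce O(n^2) memory to O(n)" point, one well-known way to get there (a sketch, not the answer's own code below) is the sort-plus-two-pointers approach: fix the first two values with nested loops, then close in on the remaining pair with two pointers, storing only the sorted copy and the result.

```python
def four_sum_two_pointer(nums, target):
    # O(n^3) time, O(n) extra memory: sort, fix two values, then scan
    # the rest with two converging pointers. No pair-sum hash map.
    nums = sorted(nums)
    n = len(nums)
    result = []
    for i in range(n - 3):
        if i > 0 and nums[i] == nums[i - 1]:
            continue  # skip duplicate first elements
        for j in range(i + 1, n - 2):
            if j > i + 1 and nums[j] == nums[j - 1]:
                continue  # skip duplicate second elements
            lo, hi = j + 1, n - 1
            while lo < hi:
                s = nums[i] + nums[j] + nums[lo] + nums[hi]
                if s < target:
                    lo += 1
                elif s > target:
                    hi -= 1
                else:
                    result.append([nums[i], nums[j], nums[lo], nums[hi]])
                    lo += 1
                    while lo < hi and nums[lo] == nums[lo - 1]:
                        lo += 1  # skip duplicate third elements
                    hi -= 1
                    while lo < hi and nums[hi] == nums[hi + 1]:
                        hi -= 1  # skip duplicate fourth elements
    return result
```

Sorting plus the duplicate-skipping guards makes the output unique without any set membership checks.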

Code (the same O(n^2) memory)

from itertools import combinations
from collections import defaultdict, Counter

def fourSum(self, nums, target):
    if len(nums) < 4:
        return []
    half_target = target // 2
    counter = Counter(nums)
    uniques_wide = sorted(counter)
    x_min, x_max = target - 3 * uniques_wide[-1], target - 3 * uniques_wide[0] # bad
    uniques = [ x for x in uniques_wide if x_min <= x <= x_max ]
    duplicates = [x for x in uniques if counter[x] > 1]

    target_minus_xy_sums = set(target - x - y for x, y in combinations(uniques, 2))
    target_minus_xy_sums |= set(target - x - x for x in duplicates)

    ab_sum_pairs, cd_sum_pairs = defaultdict(list), defaultdict(list)
    for (x, y) in combinations(uniques, 2):
        if x + y in target_minus_xy_sums:
            if x + y <= half_target:
                ab_sum_pairs[x + y].append((x, y))
            if x + y >= half_target:
                cd_sum_pairs[x + y].append((x, y))
    for x in duplicates:
        if x + x in target_minus_xy_sums:
            if x + x <= half_target:
                ab_sum_pairs[x + x].append((x, x))
            if x + x >= half_target:
                cd_sum_pairs[x + x].append((x, x))

    return [[a, b, c, d]
            for ab in ab_sum_pairs
            for (a, b) in ab_sum_pairs[ab]
            for (c, d) in cd_sum_pairs[target - ab]
            if b < c or b == c and [a, b, c, d].count(b) <= counter[b]] # if bi < ci
\$\endgroup\$
