Description
Given a binary tree whose nodes contain integer values (positive or negative), how can you count all the paths whose values add up to some target sum? Note that a path does not have to start at the root or end at a leaf; it may start at any node and end at any node below it, as long as it only goes down the tree.
The first solution
This problem can be solved by recursively going through each path in a tree and tallying up the sum at each recursion level, checking if the sum matches the target sum. This means that the current sum or the accumulating sum has to be passed down to its recursive calls.
To keep things simple, we use two recursive functions.
- One function recurses into the children of a node while tallying up the sum.
- Another function calls the above function at the current node and then calls itself recursively on each child.
Here is some rough pseudocode:
def count_paths_with_sum(node, target_sum):
    if not node:
        return 0
    # Paths that start at this node.
    paths_from_node = count_paths_with_sum_from_node(node, target_sum, 0)
    # Paths that start somewhere in the left or right subtree.
    paths_on_left = count_paths_with_sum(node.left, target_sum)
    paths_on_right = count_paths_with_sum(node.right, target_sum)
    return paths_from_node + paths_on_left + paths_on_right

def count_paths_with_sum_from_node(node, target_sum, current_sum):
    if not node:
        return 0
    current_sum += node.data
    paths = 0
    if current_sum == target_sum:
        paths += 1
    # Keep going after a match: zero or negative values below can produce further matches.
    paths += count_paths_with_sum_from_node(node.left, target_sum, current_sum)
    paths += count_paths_with_sum_from_node(node.right, target_sum, current_sum)
    return paths
Note that a leaf whose value equals the target sum is a valid path on its own, so it should also be counted.
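To see which paths are actually counted, including a leaf that matches the target on its own, here is a minimal sketch that collects the matching paths instead of just counting them. The Node class, the find_paths_with_sum and collect_from_node helpers, and the sample tree are assumptions made for illustration; they are not part of the solution above.

```python
class Node:
    """Hypothetical tree node; the field names mirror the pseudocode above."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def collect_from_node(node, target_sum, current_sum, prefix, found):
    """Record every matching downward path that starts where this walk began."""
    if node is None:
        return
    prefix.append(node.data)
    current_sum += node.data
    if current_sum == target_sum:
        found.append(list(prefix))
    collect_from_node(node.left, target_sum, current_sum, prefix, found)
    collect_from_node(node.right, target_sum, current_sum, prefix, found)
    prefix.pop()  # backtrack before returning to the parent


def find_paths_with_sum(node, target_sum, found=None):
    """Try every node as a starting point, in pre-order."""
    if found is None:
        found = []
    if node is not None:
        collect_from_node(node, target_sum, 0, [], found)
        find_paths_with_sum(node.left, target_sum, found)
        find_paths_with_sum(node.right, target_sum, found)
    return found


#         10
#        /  \
#       5   -3
#      / \    \
#     3   2    11
#    / \   \
#   3  -2   1
root = Node(10,
            Node(5, Node(3, Node(3), Node(-2)), Node(2, None, Node(1))),
            Node(-3, None, Node(11)))

print(find_paths_with_sum(root, 8))   # [[5, 3], [5, 2, 1], [-3, 11]]
print(find_paths_with_sum(root, 11))  # [[5, 3, 3], [11]]
```

With a target of 11, the leaf 11 shows up as a one-node path, which is exactly the case the note above is about.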
Performance
Since we want to find the total number of paths that add up to the target sum, and a path doesn't have to start from the root of the tree, the algorithm above does a lot of redundant work. As the paths go deeper into the tree, many recursive calls repeat additions over values that have already been summed.
Figuratively, for each node in the tree, we are trying to calculate all the paths downwards.
What is the time complexity here?
count_paths_with_sum_from_node takes up to \(O(n)\) time, because it visits every node in the subtree below its starting node once.
count_paths_with_sum, on the other hand, calls count_paths_with_sum_from_node at the current node and, on top of that, recursively calls itself on both children. In a degenerate tree (one that looks like a linked list), the depth can reach \(n\). At each level, it calls count_paths_with_sum_from_node, which already costs \(O(n)\). This means that the total running time is \(O(n \times n) = O(n^2)\).
If the tree happens to be balanced, the running time is roughly \(O(n \log n)\): a node at depth \(d\) is visited once for each of its \(d\) ancestors plus once for itself, and \(d\) is at most \(O(\log n)\). Why is it not \(O(\log^2 n)\)? Because count_paths_with_sum_from_node still has to visit every node of the subtree it starts from, not just a single root-to-leaf path.
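The two bounds can be checked on small trees. The sketch below does not run the path-counting algorithm itself; it only counts how many node visits the two-function approach performs, which is one full subtree walk per starting node. The Node class and the helpers subtree_size, total_visits, build_chain, and build_perfect are hypothetical names introduced for this experiment.

```python
class Node:
    """Hypothetical tree node, as assumed by the pseudocode above."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def subtree_size(node):
    """Nodes visited by one walk that starts at `node`."""
    if node is None:
        return 0
    return 1 + subtree_size(node.left) + subtree_size(node.right)


def total_visits(node):
    """Node visits summed over all starting nodes."""
    if node is None:
        return 0
    return subtree_size(node) + total_visits(node.left) + total_visits(node.right)


def build_chain(n):
    """Degenerate tree: n nodes, each with only a left child."""
    root = None
    for _ in range(n):
        root = Node(1, left=root)
    return root


def build_perfect(height):
    """Perfectly balanced tree with 2**height - 1 nodes."""
    if height == 0:
        return None
    return Node(1, build_perfect(height - 1), build_perfect(height - 1))


# Both trees have 31 nodes.
print(total_visits(build_chain(31)))    # 496 = 31 * 32 / 2, quadratic growth
print(total_visits(build_perfect(5)))   # 129 = 5 * 32 - 31, roughly n log n
```

For the same number of nodes, the chain needs almost four times as many visits as the balanced tree, and the gap widens as \(n\) grows.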
Optimization: Memoization
The term "memoization" was introduced by Donald Michie in the year 1968. It's based on the Latin word memorandum, meaning "to be remembered".
We can use memoization, or alternatively, dynamic programming, to make our algorithm faster.
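Basically, we store the running totals in a hash table as we visit each node in the tree. If the running total at the current node is running_sum, then every earlier point on the path from the root whose running total was running_sum - target_sum marks the start of a path that ends at the current node and adds up to the target. Since those earlier totals are already in the hash table, we can look them up in constant time, \(O(1)\).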
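The idea is easier to see on a flat list first, which is what a single root-to-leaf path looks like. The sketch below counts the contiguous runs of a list that add up to a target by remembering the running totals seen so far. The function name count_subarrays_with_sum and the sample values are assumptions chosen for illustration.

```python
from collections import defaultdict


def count_subarrays_with_sum(values, target_sum):
    """Count contiguous runs of `values` that add up to `target_sum`."""
    seen = defaultdict(int)  # running total -> how many times it has occurred
    seen[0] = 1              # the empty prefix, so runs starting at index 0 count
    running_sum = 0
    count = 0
    for value in values:
        running_sum += value
        # Every earlier prefix equal to running_sum - target_sum starts a match.
        count += seen.get(running_sum - target_sum, 0)
        seen[running_sum] += 1
    return count


# Running totals are 3, 7, 3, 7. The matches for target 4 are
# [4] at index 1, [4, -4, 4], and [4] at index 3.
print(count_subarrays_with_sum([3, 4, -4, 4], 4))  # 3
```

The tree version does the same thing along every root-to-node path, with one extra step: a running total has to be taken out of the table again when the recursion leaves a node.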
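Here is a working version in Python. The hash table is a collections.defaultdict(int), so a running total that has not been seen yet counts as zero.
from collections import defaultdict

def count_paths_with_sum(node, target_sum, running_sum=0, hash_table=None):
    # hash_table maps a running total to how often it occurs on the current root-to-node path.
    if hash_table is None:
        hash_table = defaultdict(int)
    if node is None:
        return 0
    running_sum += node.data
    # Each earlier total equal to running_sum - target_sum starts a path that ends here.
    path_totals = hash_table.get(running_sum - target_sum, 0)
    if running_sum == target_sum:
        path_totals += 1  # the path that starts at the root
    hash_table[running_sum] += 1
    path_totals += count_paths_with_sum(node.left, target_sum, running_sum, hash_table)
    path_totals += count_paths_with_sum(node.right, target_sum, running_sum, hash_table)
    # Backtrack so that other branches do not see this node's running total.
    hash_table[running_sum] -= 1
    if hash_table[running_sum] == 0:
        del hash_table[running_sum]
    return path_totals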
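A common variant, shown here only as a sketch and not as the version above, seeds the hash table with a running total of 0 so that paths starting at the root need no special case. The Node class, the count_paths wrapper, the inner walk function, and the sample tree are assumptions made for illustration.

```python
from collections import defaultdict


class Node:
    """Hypothetical tree node with the same field names as the first solution."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def count_paths(root, target_sum):
    prefix_counts = defaultdict(int)
    prefix_counts[0] = 1  # the empty prefix stands in for "start at the root"

    def walk(node, running_sum):
        if node is None:
            return 0
        running_sum += node.data
        total = prefix_counts.get(running_sum - target_sum, 0)
        prefix_counts[running_sum] += 1
        total += walk(node.left, running_sum)
        total += walk(node.right, running_sum)
        prefix_counts[running_sum] -= 1  # backtrack when leaving this node
        return total

    return walk(root, 0)


#         10
#        /  \
#       5   -3
#      / \    \
#     3   2    11
#    / \   \
#   3  -2   1
root = Node(10,
            Node(5, Node(3, Node(3), Node(-2)), Node(2, None, Node(1))),
            Node(-3, None, Node(11)))

print(count_paths(root, 8))   # 3: 5 -> 3, 5 -> 2 -> 1, and -3 -> 11
print(count_paths(root, 11))  # 2: 5 -> 3 -> 3, and the leaf 11 on its own
```

Wrapping the recursion in an inner function keeps the hash table out of the public signature, so callers only pass the tree and the target.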
The run-time complexity of this algorithm is \(O(n)\), which is a pretty good improvement!
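Why \(O(n)\), you ask? If we analyze the algorithm and ignore the recursive calls, we see that every operation takes \(O(1)\) time (on average, for the hash table lookups and updates). The function then recurses into the left and right children, so every node is visited exactly once and the overall time complexity is \(O(n)\). The space complexity is \(O(n)\) for an unbalanced tree, because both the recursion stack and the non-zero entries of the hash table grow with the depth of the tree. For a balanced tree the depth is \(O(\log n)\), so the space drops to \(O(\log n)\) as long as entries whose count falls back to zero are removed from the hash table.
The extra hash table is the price to pay: the earlier solution needs no auxiliary data structure, only the recursion stack (which is itself \(O(n)\) deep in the worst case). Still, quadratic time complexity is a big no-no!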