Background
You can solve every problem with recursion. Sometimes it's easier, sometimes it's harder. Some competitive programmers will encourage you to avoid recursion at all costs because of the stack frame overhead and things like depth limits. In this article I want to cover some patterns for structuring your recursions, and the ideas behind each one.
The scenario is that you start writing a recursive function, but then you aren't sure how to structure it. Should the top return nothing, while you alter some global state? Should each node have a return value that you combine on the way back up? Do you need a cache, or boundary checks? Should you do operations on the way into a state, or on the way out of it? Should you use a mask, or indices?
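The depth limit point is worth making concrete. In CPython you can inspect and raise the limit yourself; here is a small sketch (the numbers are illustrative):

import sys

print(sys.getrecursionlimit())   # commonly 1000 by default in CPython
sys.setrecursionlimit(10_000)    # raise it when you know the depth is large but bounded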
Method
Tail recursion
This is when each call returns just a single value to its caller. A first attempt:
from typing import List

def my_sum(items: List[int]) -> int:
    if not items: return 0
    val = items.pop()
    return my_sum(items) + val

This gives you a clear picture of your state at each step down, but there's a big issue: this is not a tail recursion, because the result has to travel back up to give you the solution. It's going to recurse all the way down, hit the base case of 0, then climb back up and add your numbers.
my_sum(5, 4, 3, 2, 1)
5 + my_sum(4, 3, 2, 1)
5 + (4 + my_sum(3, 2, 1))
5 + (4 + (3 + my_sum(2, 1)))
5 + (4 + (3 + (2 + my_sum(1))))
5 + (4 + (3 + (2 + (1 + my_sum()))))
5 + (4 + (3 + (2 + (1 + 0))))
5 + (4 + (3 + (2 + 1)))
5 + (4 + (3 + 3))
5 + (4 + 6)
5 + 10
15
Not very appealing. Proper tail recursion is nicer, for performance reasons:
def my_sum_tail(items: List[int]) -> int:
    if not items: return 0
    if len(items) == 1: return items[0]
    val = items.pop()
    items[-1] += val
    return my_sum_tail(items)

You can think of reduction functions in programming like this as well:
from functools import reduce

items = [5, 4, 3, 2, 1]
answer = reduce(lambda acc, cur: acc + cur, items)

The tail version now collapses as it goes:

my_sum_tail(5, 4, 3, 2, 1)
my_sum_tail(5, 4, 3, 3)
my_sum_tail(5, 4, 6)
my_sum_tail(5, 10)
my_sum_tail(15)
The reason compilers like this is obvious: the compiler can do tail call optimisation. It's going to turn your function into a loop rather than real recursion, so the call stack never grows. The more common way is to structure it like that reduce, with an accumulator argument:
def my_sum_2(items: List[int], acc: int = 0) -> int:
    if not items: return acc
    return my_sum_2(items, acc + items.pop())

This sort of structure is common in pure functional languages like Lisp or Haskell. A number of non-functional languages won't do this compiler optimisation (CPython famously does not), so I would not count on it being accelerated unless you know your compiler.
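To make the "turns into a loop" point concrete, here is a sketch of what that optimisation amounts to if you do it by hand, since CPython won't do it for you:

def my_sum_loop(items, acc=0):
    # the accumulator recursion above, rewritten as the loop that a
    # tail-call-optimising compiler would effectively produce
    while items:
        acc += items.pop()
    return acc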
DFS traversal orders
This one is pretty classical: it's simply a question of where you do your operation.
from dataclasses import dataclass
from typing import Optional

@dataclass
class TreeNode:
    val: int
    left: Optional['TreeNode'] = None
    right: Optional['TreeNode'] = None

def dfs(node):
    if not node: return 0
    # if you do an operation here, it's a pre-order traversal
    left = dfs(node.left)
    # if you do an operation here, it's an in-order traversal
    right = dfs(node.right)
    # if you do an operation here, it's a post-order traversal
    return left + right + node.val

You can also easily see why they're called pre-order, in-order and post-order!
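A quick sketch to see the difference, assuming the TreeNode above (the tree here is arbitrary):

def preorder(node, out):
    if not node: return out
    out.append(node.val)      # the operation happens before the recursive calls
    preorder(node.left, out)
    preorder(node.right, out)
    return out

root = TreeNode(2, TreeNode(1, None, None), TreeNode(3, None, None))
print(preorder(root, []))     # [2, 1, 3]; in-order gives [1, 2, 3], post-order [1, 3, 2]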
Tree graph
If you’re given a tree graph you can track the parent of your current node and skip the visited set!
adj_list: List[List[int]] = [ . . . ]

def dfs(cur: int, parent: int):
    for neigh in adj_list[cur]:
        if neigh != parent:
            dfs(neigh, cur)

Multi return values
This one is fairly obvious, and all the recursions that can return None are technically already doing it: you can return multiple values. It tends to show up in tree DP. Here you need the maximum loot from a tree, but you cannot rob directly connected nodes, so from each node you return two values: the best loot if you rob the current node, and the best if you skip it. This can also help with colouring problems.
def rob_sub(node: Optional[TreeNode]) -> tuple[int, int]:
    if not node:
        return (0, 0)
    left = rob_sub(node.left)
    right = rob_sub(node.right)
    rob_current = node.val + left[1] + right[1]
    skip_current = max(left) + max(right)
    return (rob_current, skip_current)

Caching
If you put a cache on top of your recursion, it almost instantly becomes a solution for a dynamic programming problem!
cache = {}

def dp(i, j):
    key = (i, j)
    if key in cache:
        return cache[key]
    # ... do your operations to compute ans
    cache[key] = ans
    return ans

Side note: there is also the option of using @cache from functools.
from functools import cache

@cache
def dp(i, j):
    # do your operation
    return ans

Whenever you use this in an interview, your interviewer will probably ask whether you know how to memoize it yourself. So… I usually don't use the @cache decorator at all, because it can be treated as a hack. It's trivial to build a key for your own cache anyway, so you can avoid wasting time on that discussion entirely by not using it!
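For a concrete feel of the recursion-plus-cache pattern, here is a sketch for a hypothetical grid-paths problem (count the ways to walk from the top-left to the bottom-right of a grid, moving only right or down):

def count_paths(rows: int, cols: int) -> int:
    cache = {}

    def dp(i, j):
        if i == rows - 1 or j == cols - 1:
            return 1                      # a single path along the last row or column
        if (i, j) in cache:
            return cache[(i, j)]
        cache[(i, j)] = dp(i + 1, j) + dp(i, j + 1)
        return cache[(i, j)]

    return dp(0, 0)

print(count_paths(3, 3))  # 6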
Backtracking
Backtracking is a regular DFS, but you have some global state. When you enter a child you want to do some operation on that global state, and when you exit you want to undo it.
ans = []
state = [5, 4, 3, 2, 1]

def permutations(idx):
    if idx == len(state):
        ans.append([x for x in state])
        return
    for i in range(idx, len(state)):
        # apply operation
        state[i], state[idx] = state[idx], state[i]
        permutations(idx + 1)
        # undo operation
        state[i], state[idx] = state[idx], state[i]

For the above this is a global state, but you can also pass it in as a parameter:
def permutations(idx, state):
    if idx == len(state):
        ans.append([x for x in state])
        return
    for i in range(idx, len(state)):
        # apply operation
        state[i], state[idx] = state[idx], state[i]
        permutations(idx + 1, state)
        # undo operation
        state[i], state[idx] = state[idx], state[i]

As long as it's passed by reference and not by value, this is fine. Usually I don't do this, because in my head more function arguments == bad, and it's global state anyway. Your state can be any data structure. Frequency maps and deques are common. Usually these are quite obvious… until they aren't…
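A quick usage sketch, assuming the globals from above: every ordering of the five values shows up exactly once.

ans = []
state = [5, 4, 3, 2, 1]
permutations(0, state)
print(len(ans))   # 120, i.e. 5! distinct orderings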
Tracking used values
Let's say you want to build permutations by adding one number at a time from your available numbers. There are two ways to track what's still available: a set of unused numbers, or a mask. Usually the mask approach runs faster, but that doesn't always make it the best approach.
Mask approach
def perms(values: List[int], mask: List[bool], state: List[int]):
    if all(mask):
        return [state[:]]
    ans = []
    for i, val in enumerate(values):
        if mask[i]: continue
        mask[i] = True
        state.append(val)
        ans += perms(values, mask, state)
        state.pop()
        mask[i] = False
    return ans

Set approach:
def perms(values: Set[int], state: List[int]):
    if not values:
        return [state]
    ans = []
    for val in values:
        # expensive! you're creating new sets and lists every iteration,
        # so it's O(n)
        ans += perms(values - {val}, state + [val])
    return ans
An alternative is to mutate the same set in place and undo the change, at the cost of copying the keys on every loop:

def perms(values: Set[int], state: List[int]):
    if not values:
        return [state[:]]
    ans = []
    # an extra O(n) per call because of list(values)!
    for val in list(values):
        values.discard(val)
        state.append(val)
        ans += perms(values, state)
        state.pop()
        values.add(val)
    return ans

Now, this has issues. What if your initial values contain duplicate elements? A set silently drops them, so you lose permutations. Also, compared to the mask version the time complexity just increased massively! You could argue that for permutations your time complexity is huge anyway, so it's fine. Still, not good! For the duplication problem, you can store indices in your set instead:
def perms(values: List[int], indices: Set[int], state: List[int]):
    if not indices:
        return [state[:]]
    ans = []
    for i in list(indices):
        indices.discard(i)
        state.append(values[i])
        ans += perms(values, indices, state)
        state.pop()
        indices.add(i)
    return ans

Or you can change over to a frequency map:
def perms(freqs: Dict[int, int], state: List[int]):
    if all(count == 0 for count in freqs.values()):
        return [state[:]]
    ans = []
    for val in freqs:
        if freqs[val] == 0: continue
        freqs[val] -= 1
        state.append(val)
        ans += perms(freqs, state)
        state.pop()
        freqs[val] += 1
    return ans

You can pick depending on the complexity of the problem at hand, but probably avoid creating a new data structure every for-loop.
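As a quick check that the frequency-map version handles duplicates (a sketch, assuming the perms just above), each distinct ordering appears exactly once:

from collections import Counter

print(perms(Counter([1, 1, 2]), []))
# [[1, 1, 2], [1, 2, 1], [2, 1, 1]]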
Splitting state
This is an approach where, instead of masking, you split your remaining values at every level of the recursion. It's particularly useful for tree building, such as generating all unique binary search trees:
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

# needs nums to be sorted
def generateTrees(nums: list[int]) -> list[TreeNode]:
    if not nums:
        return [None]
    all_trees = []
    for i in range(len(nums)):
        root_val = nums[i]
        left_subtrees = generateTrees(nums[:i])
        right_subtrees = generateTrees(nums[i+1:])
        for l_tree in left_subtrees:
            for r_tree in right_subtrees:
                current_tree = TreeNode(root_val)
                current_tree.left = l_tree
                current_tree.right = r_tree
                all_trees.append(current_tree)
    return all_trees

Or quicksort:
def quicksort(arr):
    if len(arr) <= 1: return arr
    pivot = arr[0]
    return quicksort([x for x in arr[1:] if x < pivot]) + \
        [pivot] + \
        quicksort([x for x in arr[1:] if x >= pivot])

You can be an animal and write it on one line!
qs = lambda l: qs([x for x in l[1:] if x <= l[0]]) + [l[0]] + qs([x for x in l[1:] if x > l[0]]) if l else []

Intervals/ranges
This can be an acceleration of the split-state approach when each recursion works on a contiguous range: you don't need to create new arrays constantly, you just pass start and end indices. The classic example is merge sort:
from typing import List

def merge_sort(arr: List[int], start: int, end: int) -> None:
    if start >= end:
        return
    mid = (start + end) // 2
    merge_sort(arr, start, mid)
    merge_sort(arr, mid + 1, end)
    merge(arr, start, mid, end)

def merge(arr: List[int], start: int, mid: int, end: int):
    left_half = arr[start : mid + 1]
    right_half = arr[mid + 1 : end + 1]
    i = j = 0
    k = start
    while i < len(left_half) and j < len(right_half):
        if left_half[i] <= right_half[j]:
            arr[k] = left_half[i]
            i += 1
        else:
            arr[k] = right_half[j]
            j += 1
        k += 1
    while i < len(left_half):
        arr[k] = left_half[i]
        i += 1
        k += 1
    while j < len(right_half):
        arr[k] = right_half[j]
        j += 1
        k += 1

Building segment trees also falls into this same category. You have an array of numbers, and every recursion down you narrow the range that you're considering:
def build_segment_tree(arr: List[int], tree: List[int], node: int, start: int, end: int):
    if start == end:
        tree[node] = arr[start]
        return
    mid = (start + end) // 2
    build_segment_tree(arr, tree, 2 * node, start, mid)
    build_segment_tree(arr, tree, 2 * node + 1, mid + 1, end)
    tree[node] = tree[2 * node] + tree[2 * node + 1]

You can also generate all the unique binary trees with this technique. It just saves a bit of time on the list building, though:
from dataclasses import dataclass
from itertools import product
from typing import List

@dataclass
class TreeNode:
    val: int
    left: 'TreeNode'
    right: 'TreeNode'

def build_trees(values: List[int], start: int, end: int) -> List[TreeNode]:
    if start > end:
        return [None]
    answers = []
    for i in range(start, end + 1):
        lefts = build_trees(values, start, i - 1)
        rights = build_trees(values, i + 1, end)
        for left, right in product(lefts, rights):
            root = TreeNode(values[i], left, right)
            answers.append(root)
    return answers

Flattening
You can flatten nested lists or json-like structures quite easily too:
def deep_flatten(items):
    result = []
    for item in items:
        if isinstance(item, list):
            # Recurse only on the specific nested type
            result.extend(deep_flatten(item))
        else:
            result.append(item)
    return result

Flood fills
This is more of a grid problem, but it’s also possible within graphs. You have some boundary rule about where you can stop, and you need to keep infecting or flooding until you hit that boundary.
def flood_fill(grid, r, c, new_color, old_color):
    # assumes new_color != old_color, otherwise the recursion never terminates
    if r < 0 or r >= len(grid) or c < 0 or c >= len(grid[0]):
        return
    if grid[r][c] != old_color:
        return
    grid[r][c] = new_color
    for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
        flood_fill(grid, r + dr, c + dc, new_color, old_color)

Multi-function recursion
The classic example of this is a regex parser. You have multiple patterns that you need to handle, and elements can contain each other. This is very common within compilers as well.
class ExpressionParser:
    def __init__(self, tokens):
        self.tokens = tokens
        self.pos = 0

    # match(tok), expect(tok) and consume_number() are small token
    # helpers, omitted here to keep the pattern visible

    def parse_expression(self):
        # Expression -> Term { + Term }
        left = self.parse_term()
        while self.match('+'):
            right = self.parse_term()
            left = ('ADD', left, right)
        return left

    def parse_term(self):
        # Term -> Factor { * Factor }
        left = self.parse_factor()
        while self.match('*'):
            right = self.parse_factor()
            left = ('MUL', left, right)
        return left

    def parse_factor(self):
        # Factor -> ( Expression ) | Number
        if self.match('('):
            node = self.parse_expression()  # Indirect recursion happens here!
            self.expect(')')
            return node
        return self.consume_number()

Conclusion
In this article we went over a large number of recursion patterns. Since you can write almost any problem as a recursion, there are still plenty of patterns left to explore. Here is a mix of higher-difficulty problems that you can try to solve with recursion:
- https://leetcode.com/problems/binary-tree-cameras/description/
- https://leetcode.com/problems/remove-invalid-parentheses/description/
- https://leetcode.com/problems/longest-increasing-path-in-a-matrix/description/
- https://leetcode.com/problems/regular-expression-matching/description/
- https://codeforces.com/problemset/problem/1363/E
- https://codeforces.com/problemset/problem/837/D
- https://codeforces.com/problemset/problem/1527/D
- https://usaco.org/index.php?page=viewproblem2&cpid=648
- https://usaco.org/index.php?page=viewproblem2&cpid=1018
Break the problem into a current state, and think about how you want to traverse to the next one. Then think about how the code should be written to support that traversal.
zyros