Background

In principle, you can solve every problem with recursion. Sometimes it’s easier, sometimes it’s harder. Some competitive programmers will encourage you to avoid recursion at all costs because of stack frame overhead and recursion depth limits. In this article I want to cover some patterns for structuring your recursions, and the questions worth asking along the way.

The typical situation is that you start writing a recursive function, but then you aren’t sure how to structure it. Should the top-level call return nothing while you alter some global state? Should each node return a value that bubbles back up? Do you need a cache, or boundary checks? Should you do your operations on the way into a state, or on the way out? Should you track what is used with a mask, or with indices?

Method

Tail recursion

Start with a recursion that returns just a single value from each state:

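from typing import List

def my_sum(items: List[int]) -> int:
    if not items: return 0
    val = items.pop()  # note: this consumes the caller's list
    return my_sum(items) + val  # the addition runs after the recursive call returns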

This gives you a clear picture of your state at each level, but there’s a big issue. This is not tail recursion, because the addition still has to happen after the recursive call returns. It’s going to recurse all the way down, hit the empty list, then climb back up and add your numbers.

my_sum([5, 4, 3, 2, 1])
my_sum([5, 4, 3, 2]) + 1
(my_sum([5, 4, 3]) + 2) + 1
((my_sum([5, 4]) + 3) + 2) + 1
(((my_sum([5]) + 4) + 3) + 2) + 1
((((my_sum([]) + 5) + 4) + 3) + 2) + 1
((((0 + 5) + 4) + 3) + 2) + 1
(((5 + 4) + 3) + 2) + 1
((9 + 3) + 2) + 1
(12 + 2) + 1
14 + 1
15

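Not very appealing. Tail recursion is nicer for performance reasons: the recursive call is the very last thing the function does, so nothing is left pending on the way back up.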

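def my_sum_tail(items: List[int]) -> int:
    if not items: return 0
    if len(items) == 1: return items[0]
    val = items.pop()
    items[-1] += val  # fold the popped value into the new last element
    return my_sum_tail(items)  # tail call: nothing left to do afterwards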

You can think of reduction functions in the same way:

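from functools import reduce

items = [5, 4, 3, 2, 1]
answer = reduce(lambda acc, cur: acc + cur, items)

Traced out, the calls to my_sum_tail collapse the list step by step:

my_sum_tail([5, 4, 3, 2, 1])
my_sum_tail([5, 4, 3, 3])
my_sum_tail([5, 4, 6])
my_sum_tail([5, 10])
my_sum_tail([15])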

The reason compilers like this is that they can apply tail call optimisation. Since nothing happens after the recursive call, the compiler can reuse the current stack frame and effectively turn your function into a loop, so you don’t have to rewrite it by hand. The more common way is to structure it like that reduce, with an accumulator argument:

def my_sum_2(items: List[int], acc: int=0):
    if not items: return acc
    return my_sum_2(items, acc + items.pop())

This sort of operation is common in functional languages like Lisp or Haskell. Many other languages don’t have this compiler optimisation (CPython, for example, does not eliminate tail calls), so I would not count on it being accelerated unless you know your compiler.
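
To make the "turned into a loop" idea concrete, here is a minimal sketch of what a tail call optimiser would effectively produce, next to an accumulator recursion. The names my_sum_loop and my_sum_acc are made up for this illustration, and the index parameter is an assumption added so the input list is not mutated.

```python
import sys
from typing import List

def my_sum_loop(items: List[int]) -> int:
    # Hand-written equivalent of the accumulator recursion:
    # the accumulator becomes a local variable and the call becomes a loop.
    acc = 0
    for value in items:
        acc += value
    return acc

def my_sum_acc(items: List[int], i: int = 0, acc: int = 0) -> int:
    # Accumulator-style tail recursion over an index (no mutation of items).
    if i == len(items):
        return acc
    return my_sum_acc(items, i + 1, acc + items[i])

small = [5, 4, 3, 2, 1]
print(my_sum_loop(small), my_sum_acc(small))

# CPython keeps one frame per call, so a long enough input exhausts the stack
# even though the function is written in tail form.
big = list(range(sys.getrecursionlimit() + 100))
try:
    my_sum_acc(big)
except RecursionError:
    print("recursion limit hit; the loop still works:", my_sum_loop(big))
```

In a language with guaranteed tail call elimination both versions would use constant stack space; in CPython only the loop does.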

DFS traversal orders

This one is a classic. The only question is: where do you do your operation?

from dataclasses import dataclass
from typing import Optional
 
@dataclass
class TreeNode:
    val: int
    left: Optional['TreeNode']
    right: Optional['TreeNode']
 
def dfs(node):
    if not node: return 0
 
    # if you do an operation here, it'll be pre-order traversal
    left = dfs(node.left)
    # if you do an operation here, it'll be in-order traversal
    right = dfs(node.right)
    # if you do an operation here, it'll be post-order traversal
    
    return left + right + node.val

You can also easily see why they’re called pre-order, in-order and post-order!
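
As a small illustration of the three positions, the sketch below records the node value at each of them. The Node class and the traverse function are hypothetical names for this example only; Node is a simplified stand-in for the TreeNode above, with default children.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Node:
    val: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None

def traverse(node: Optional[Node], pre: List[int], ino: List[int], post: List[int]) -> None:
    if node is None:
        return
    pre.append(node.val)      # before either child: pre-order
    traverse(node.left, pre, ino, post)
    ino.append(node.val)      # between the children: in-order
    traverse(node.right, pre, ino, post)
    post.append(node.val)     # after both children: post-order

#     2
#    / \
#   1   3
root = Node(2, Node(1), Node(3))
pre, ino, post = [], [], []
traverse(root, pre, ino, post)
print(pre)   # [2, 1, 3]
print(ino)   # [1, 2, 3]
print(post)  # [1, 3, 2]
```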

Tree graph

If you’re given a tree as an undirected graph (connected, no cycles), you can pass the parent of your current node down the recursion and skip the visited set!

adj_list: List[List[int]] = [ . . . ]
def dfs(cur: int, parent: int):
    for neigh in adj_list[cur]:
        if neigh != parent:
            dfs(neigh, cur)
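
Here is one way this could look on a concrete input: a sketch that computes subtree sizes. The function name subtree_sizes, the sample adjacency list, and the choice of -1 as the "no parent" sentinel for the root are all assumptions for the example.

```python
from typing import List

def subtree_sizes(adj_list: List[List[int]], root: int = 0) -> List[int]:
    sizes = [1] * len(adj_list)   # every node counts itself

    def dfs(cur: int, parent: int) -> None:
        for neigh in adj_list[cur]:
            if neigh != parent:   # in a tree, the only already-visited neighbour is the parent
                dfs(neigh, cur)
                sizes[cur] += sizes[neigh]

    dfs(root, -1)  # -1 is a sentinel: the root has no parent
    return sizes

# Edges: 0-1, 0-2, 1-3, 1-4
adj_list = [[1, 2], [0, 3, 4], [0], [1], [1]]
print(subtree_sizes(adj_list))  # [5, 3, 1, 1, 1]
```

This only works when the graph really is a tree. With a cycle, the parent check is not enough and you need the visited set back.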

Multi return values

This one is fairly obvious, and every recursion that can return None is technically already doing it: you can return multiple values. This tends to happen in DP on trees. Here you need the maximum loot from a tree, but you cannot rob two directly connected nodes. So each call returns two values: the best total if you rob the current node, and the best total if you skip it. This can also help with colouring problems.

def rob_sub(node: Optional[TreeNode]) -> tuple[int, int]:
    if not node: 
        return (0, 0)
    
    left = rob_sub(node.left)
    right = rob_sub(node.right)
    
    rob_current = node.val + left[1] + right[1]
    skip_current = max(left) + max(right)
    
    return (rob_current, skip_current)
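
The same shape shows up in other tree problems. As a different illustration (not the robbery problem), this sketch returns a (height, diameter) pair from each call; Node and height_and_diameter are hypothetical names for this example.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class Node:
    val: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None

def height_and_diameter(node: Optional[Node]) -> Tuple[int, int]:
    # Returns (height counted in nodes, longest path counted in edges).
    if node is None:
        return (0, 0)
    left_h, left_d = height_and_diameter(node.left)
    right_h, right_d = height_and_diameter(node.right)
    height = 1 + max(left_h, right_h)
    through_here = left_h + right_h   # longest path passing through this node
    return (height, max(through_here, left_d, right_d))

#       1
#      / \
#     2   3
#    / \
#   4   5
root = Node(1, Node(2, Node(4), Node(5)), Node(3))
print(height_and_diameter(root))  # (3, 3)
```

The parent needs the height to extend paths through itself, and the diameter to remember the best answer found so far, which is why one return value is not enough.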

Caching

If you put a cache on top of your recursion, it almost instantly becomes a solution for a dynamic programming problem!

cache = {}
def dp(i, j):
    key = (i, j)
    if key in cache:
        return cache[key]
    
    # ... do your operations
    
    cache[key] = ans
    return ans

Side note: you will also see the advice to use @cache from functools (available since Python 3.9).

from functools import cache
@cache
def dp(i, j):
    # do your operation
    return ans

Whenever you use this in an interview, your interviewer will probably ask you whether you know how to memoize it by hand. So… I usually don’t use the @cache decorator at all, because it can be treated as a hack. It’s trivial to make a key for your cache anyway, so you can avoid wasting time on that discussion entirely by not using it!
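
To show the manual cache on a complete problem, here is a sketch using longest common subsequence as the example. The problem choice and the name lcs_length are assumptions for illustration; the cache pattern is the same (i, j) key as above.

```python
from typing import Dict, Tuple

def lcs_length(a: str, b: str) -> int:
    cache: Dict[Tuple[int, int], int] = {}

    def dp(i: int, j: int) -> int:
        # Length of the longest common subsequence of a[i:] and b[j:].
        if i == len(a) or j == len(b):
            return 0
        key = (i, j)
        if key in cache:
            return cache[key]
        if a[i] == b[j]:
            ans = 1 + dp(i + 1, j + 1)
        else:
            ans = max(dp(i + 1, j), dp(i, j + 1))
        cache[key] = ans
        return ans

    return dp(0, 0)

print(lcs_length("abcde", "ace"))  # 3
```

Keeping the cache inside the outer function means each call to lcs_length starts with a fresh cache, which avoids stale entries between inputs.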

Backtracking

Backtracking is a regular DFS, but you have some global state. When you enter a child you want to do some operation on that global state, and when you exit you want to undo it.

ans = []
state = [5, 4, 3, 2, 1]

def permutations(idx):
    if idx == len(state):
        ans.append(state[:])  # copy, because state keeps changing
        return
        
    # start at idx (not idx + 1) so "leave this position as it is" is also tried
    for i in range(idx, len(state)):
        # apply operation 
        state[i], state[idx] = state[idx], state[i]
        permutations(idx + 1)
        # undo operation 
        state[i], state[idx] = state[idx], state[i]

In the example above the state is global, but you can also pass it in as a parameter:

def permutations(idx, state):
    if idx == len(state):
        ans.append(state[:])  # copy, because state keeps changing
        return
        
    # start at idx (not idx + 1) so "leave this position as it is" is also tried
    for i in range(idx, len(state)):
        # apply operation 
        state[i], state[idx] = state[idx], state[i]
        permutations(idx + 1, state)
        # undo operation 
        state[i], state[idx] = state[idx], state[i]

As long as it’s passed by reference and not copied, this is fine. Usually I don’t do this because in my head more function arguments == bad, and it’s global state anyway. Your state can be any data structure. Frequency maps and deques are common. Usually these are quite obvious… until they aren’t…
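
The apply/recurse/undo shape is not specific to swaps. As another sketch, here the shared state is a list that is appended to on the way in and popped on the way out; the problem (subsets that add up to a target) and the name subsets_with_sum are assumptions chosen for illustration.

```python
from typing import List

def subsets_with_sum(nums: List[int], target: int) -> List[List[int]]:
    ans: List[List[int]] = []
    path: List[int] = []   # shared state: mutated going in, restored coming out

    def backtrack(idx: int, remaining: int) -> None:
        if idx == len(nums):
            if remaining == 0:
                ans.append(path[:])  # copy the state before it changes again
            return
        # choice 1: take nums[idx]
        path.append(nums[idx])                     # apply
        backtrack(idx + 1, remaining - nums[idx])
        path.pop()                                 # undo
        # choice 2: skip nums[idx]
        backtrack(idx + 1, remaining)

    backtrack(0, target)
    return ans

print(subsets_with_sum([5, 4, 3, 2, 1], 6))  # [[5, 1], [4, 2], [3, 2, 1]]
```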

Tracking used values

Let’s say you want to build permutations by adding one number at a time from your available numbers. There are two ways to track what is still available: a set of unused numbers, or a mask. Usually the mask approach will run faster, but that doesn’t always mean it’s the best approach.

Mask approach

def perms(values: List[int], mask: List[bool], state: List[int]):
    if all(mask):
        return [state[:]]
    
    ans = []
    for i, val in enumerate(values):
        if mask[i]: continue
        mask[i] = True
        state.append(val)
        ans += perms(values, mask, state)
        state.pop()
        mask[i] = False
    return ans
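
A common variant, sketched here as an assumption rather than something the code above does, is to pack the mask into a single integer. Because integers are immutable, each call gets its own mask and only the state list needs undoing. The name perms_bitmask is made up for this example.

```python
from typing import List

def perms_bitmask(values: List[int]) -> List[List[int]]:
    full = (1 << len(values)) - 1   # all bits set: every index used
    state: List[int] = []
    ans: List[List[int]] = []

    def go(mask: int) -> None:
        if mask == full:
            ans.append(state[:])
            return
        for i, val in enumerate(values):
            if mask & (1 << i):      # index i is already used
                continue
            state.append(val)
            go(mask | (1 << i))      # new int per call, so no undo needed for the mask
            state.pop()

    go(0)
    return ans

print(perms_bitmask([1, 2, 3]))
# [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
```

An integer mask is also hashable, which makes it convenient as part of a cache key if the recursion later turns into a DP.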

Set approach:

def perms(values: Set[int], state: List[int]):
    if not values:
        return [state]
    
    ans = []
    for val in values:
        # expensive! you're creating a new set and a new list on every
        # iteration, and each copy is O(n)
        ans += perms(values - {val}, state + [val])
        
    return ans
 
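def perms(values: Set[int], state: List[int]):
    if not values:
        return [state[:]]  # copy: state is shared and keeps mutating
    
    ans = []
    # list(values) costs an extra O(n) per call, but we need the snapshot
    # because the set is mutated inside the loop
    for val in list(values):
        values.discard(val)
        state.append(val)
        ans += perms(values, state)
        state.pop()
        values.add(val)
        
    return ans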

Now, this has issues. What if your initial values contain duplicate elements? A set silently drops them. Also, the time complexity just increased massively in the first solution! You could argue that for permutations your time complexity is high anyway, so it’s fine. Still, not good! For the duplication problem, you can store indices in your set instead.

def perms(values: List[int], indices: Set[int], state: List[int]):
    # indices holds the positions in values that are still unused
    if not indices:
        return [state[:]]
    
    ans = []
    for i in list(indices):
        indices.discard(i)
        state.append(values[i])
        ans += perms(values, indices, state)
        state.pop()
        indices.add(i)
        
    return ans

Or you can switch to a frequency map:

def perms(freqs: Dict[int, int], state: List[int]):
    # freqs maps each value to how many copies are still unused
    if all(count == 0 for count in freqs.values()):
        return [state[:]]
    
    ans = []
    for val in freqs.keys():
        if freqs[val] == 0: continue
        freqs[val] -= 1
        state.append(val)
        
        ans += perms(freqs, state)
        
        state.pop()
        freqs[val] += 1
        
    return ans

You can pick depending on the complexity of the problem at hand, but try to avoid creating a new data structure on every loop iteration.

Splitting state

This is an approach where, instead of marking elements as used, you split the remaining state into independent parts at every step and recurse on each part. It’s particularly useful for tree-building, such as generating all unique binary search trees:

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
 
# needs nums to be sorted
def generateTrees(nums: list[int]) -> list[TreeNode]:
    if not nums:
        return [None]
    
    all_trees = []
    
    for i in range(len(nums)):
        root_val = nums[i]
        
        left_subtrees = generateTrees(nums[:i])
        right_subtrees = generateTrees(nums[i+1:])
        
        for l_tree in left_subtrees:
            for r_tree in right_subtrees:
                current_tree = TreeNode(root_val)
                current_tree.left = l_tree
                current_tree.right = r_tree
                all_trees.append(current_tree)
                
    return all_trees

Or quicksort:

def quicksort(arr):
    if len(arr) <= 1: return arr
    pivot = arr[0]
    return quicksort(
        [x for x in arr[1:] if x < pivot]) + \
        [pivot] + \
        quicksort([x for x in arr[1:] if x >= pivot])

You can be an animal and write it on one line!

qs = lambda l: qs([x for x in l[1:] if x <= l[0]]) + [l[0]] + qs([x for x in l[1:] if x > l[0]]) if l else []

Intervals/ranges

This is a faster variant of the split-state approach for when the elements are never reordered. Instead of creating new arrays at every level, you pass start and end indices into the same array. The classic example is merge sort:

from typing import List
 
def merge_sort(arr: List[int], start: int, end: int) -> None:
    if start >= end:
        return
    mid = (start + end) // 2
    merge_sort(arr, start, mid)
    merge_sort(arr, mid + 1, end)
    merge(arr, start, mid, end)
 
def merge(arr: List[int], start: int, mid: int, end: int):
    left_half = arr[start : mid + 1]
    right_half = arr[mid + 1 : end + 1]
    
    i = j = 0
    k = start
    
    while i < len(left_half) and j < len(right_half):
        if left_half[i] <= right_half[j]:
            arr[k] = left_half[i]
            i += 1
        else:
            arr[k] = right_half[j]
            j += 1
        k += 1
        
    while i < len(left_half):
        arr[k] = left_half[i]
        i += 1
        k += 1
    while j < len(right_half):
        arr[k] = right_half[j]
        j += 1
        k += 1

Building segment trees also falls into this same category. You have a set of numbers, and every recursion down you’re changing the range that you’re considering.

def build_segment_tree(arr: List[int], tree: List[int], node: int, start: int, end: int):
    if start == end:
        tree[node] = arr[start]
        return
 
    mid = (start + end) // 2
    build_segment_tree(arr, tree, 2 * node, start, mid)
    build_segment_tree(arr, tree, 2 * node + 1, mid + 1, end)
    
    tree[node] = tree[2 * node] + tree[2 * node + 1]
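
Querying follows the same range-narrowing recursion. The sketch below repeats the build step (renamed build so the block runs on its own) and adds a hypothetical query function for inclusive range sums. It assumes the root lives at index 1 and that the tree array has 4n slots, which is a safe upper bound for this layout.

```python
from typing import List

def build(arr: List[int], tree: List[int], node: int, start: int, end: int) -> None:
    if start == end:
        tree[node] = arr[start]
        return
    mid = (start + end) // 2
    build(arr, tree, 2 * node, start, mid)
    build(arr, tree, 2 * node + 1, mid + 1, end)
    tree[node] = tree[2 * node] + tree[2 * node + 1]

def query(tree: List[int], node: int, start: int, end: int, lo: int, hi: int) -> int:
    # Sum of arr[lo..hi] (inclusive); this node covers arr[start..end].
    if hi < start or end < lo:
        return 0                      # no overlap with the query range
    if lo <= start and end <= hi:
        return tree[node]             # node range lies fully inside the query
    mid = (start + end) // 2
    return (query(tree, 2 * node, start, mid, lo, hi)
            + query(tree, 2 * node + 1, mid + 1, end, lo, hi))

arr = [5, 4, 3, 2, 1]
tree = [0] * (4 * len(arr))
build(arr, tree, 1, 0, len(arr) - 1)
print(query(tree, 1, 0, len(arr) - 1, 1, 3))  # 4 + 3 + 2 = 9
```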

You can also generate all the unique binary trees with this technique. It only saves a bit of time on the list building, though.

from dataclasses import dataclass
from itertools import product
from typing import List, Optional

@dataclass
class TreeNode:
    val: int
    left: Optional['TreeNode']
    right: Optional['TreeNode']

def build_trees(values: List[int], start: int, end: int) -> List[Optional[TreeNode]]:
    if start > end:
        return [None]
 
    answers = []
    for i in range(start, end + 1):
        lefts = build_trees(values, start, i - 1)
        rights = build_trees(values, i + 1, end)
        
        for left, right in product(lefts, rights):
            root = TreeNode(values[i], left, right)
            answers.append(root)
 
    return answers

Flattening

You can flatten nested lists or JSON-like structures quite easily too:

def deep_flatten(items):
    result = []
    for item in items:
        if isinstance(item, list):
            # Recurse only on specific type
            result.extend(deep_flatten(item))
        else:
            result.append(item)
    return result
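
For the JSON-like case, one possible sketch is to recurse on both dicts and lists and build a path string for each leaf. The name flatten_json and the dotted-path / [index] key format are assumptions for this example, not a standard. Empty dicts and lists produce no entries.

```python
from typing import Any, Dict

def flatten_json(obj: Any, prefix: str = "") -> Dict[str, Any]:
    flat: Dict[str, Any] = {}
    if isinstance(obj, dict):
        for key, value in obj.items():
            path = f"{prefix}.{key}" if prefix else str(key)
            flat.update(flatten_json(value, path))
    elif isinstance(obj, list):
        for index, value in enumerate(obj):
            flat.update(flatten_json(value, f"{prefix}[{index}]"))
    else:
        flat[prefix] = obj   # leaf: str, number, bool or None
    return flat

doc = {"user": {"name": "Ada", "tags": ["a", "b"]}, "active": True}
print(flatten_json(doc))
# {'user.name': 'Ada', 'user.tags[0]': 'a', 'user.tags[1]': 'b', 'active': True}
```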

Flood fills

This is more of a grid problem, but it’s also possible within graphs. You have some boundary rule about where you can stop, and you need to keep infecting or flooding until you hit that boundary.

def flood_fill(grid, r, c, new_color, old_color):
    if new_color == old_color:
        return  # nothing to change, and recursing would never terminate
    if r < 0 or r >= len(grid) or c < 0 or c >= len(grid[0]):
        return
    if grid[r][c] != old_color:
        return
        
    grid[r][c] = new_color 
    
    for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
        flood_fill(grid, r + dr, c + dc, new_color, old_color)
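
A typical use of this pattern is counting connected regions. The sketch below assumes a grid of 1s (land) and 0s (water) and counts 4-connected islands by flooding each one as it is found. count_islands and sink are hypothetical names, and note that the grid is modified in place.

```python
from typing import List

def count_islands(grid: List[List[int]]) -> int:
    rows = len(grid)
    cols = len(grid[0]) if grid else 0

    def sink(r: int, c: int) -> None:
        # Boundary rule: stop at the grid edge or at water (0).
        if r < 0 or r >= rows or c < 0 or c >= cols or grid[r][c] != 1:
            return
        grid[r][c] = 0   # flooding the cell doubles as the visited marker
        for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
            sink(r + dr, c + dc)

    islands = 0
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] == 1:
                islands += 1
                sink(r, c)
    return islands

grid = [
    [1, 1, 0, 0],
    [0, 1, 0, 1],
    [0, 0, 0, 1],
    [1, 0, 0, 0],
]
print(count_islands(grid))  # 3
```

On large grids the recursion depth can reach rows * cols, so an explicit stack may be the safer choice in Python.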

Multi-function recursion

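The classic example of this is a parser, whether for regexes or for arithmetic expressions like the one below. You have multiple patterns that you need to handle, and elements can contain each other, so the functions end up calling each other. This is very common within compilers as well.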

class ExpressionParser:
    def __init__(self, tokens):
        self.tokens = tokens
        self.pos = 0
 
    def parse_expression(self):
        # Expression -> Term { + Term }
        left = self.parse_term()
        while self.match('+'):
            right = self.parse_term()
            left = ('ADD', left, right)
        return left
 
    def parse_term(self):
        # Term -> Factor { * Factor }
        left = self.parse_factor()
        while self.match('*'):
            right = self.parse_factor()
            left = ('MUL', left, right)
        return left
 
    def parse_factor(self):
        # Factor -> ( Expression ) | Number
        if self.match('('):
            node = self.parse_expression() # Indirect recursion happens here!
            self.expect(')')
            return node
        return self.consume_number()
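
The class above relies on match, expect and consume_number helpers that are not shown. The sketch below is one guess at how they might look, wrapped in a runnable variant that evaluates the expression directly instead of building ('ADD', ...) tuples. The Calculator, tokenize and evaluate names, and the error handling, are all assumptions for illustration.

```python
import re
from typing import List

def tokenize(text: str) -> List[str]:
    # Numbers, '+', '*' and parentheses; anything else is silently skipped.
    return re.findall(r"\d+|[+*()]", text)

class Calculator:
    def __init__(self, tokens: List[str]):
        self.tokens = tokens
        self.pos = 0

    def match(self, token: str) -> bool:
        # Consume the next token only if it equals `token`.
        if self.pos < len(self.tokens) and self.tokens[self.pos] == token:
            self.pos += 1
            return True
        return False

    def expect(self, token: str) -> None:
        if not self.match(token):
            raise SyntaxError(f"expected {token!r} at token {self.pos}")

    def consume_number(self) -> int:
        if self.pos >= len(self.tokens) or not self.tokens[self.pos].isdigit():
            raise SyntaxError(f"expected a number at token {self.pos}")
        self.pos += 1
        return int(self.tokens[self.pos - 1])

    def parse_expression(self) -> int:
        # Expression -> Term { + Term }
        value = self.parse_term()
        while self.match('+'):
            value += self.parse_term()
        return value

    def parse_term(self) -> int:
        # Term -> Factor { * Factor }
        value = self.parse_factor()
        while self.match('*'):
            value *= self.parse_factor()
        return value

    def parse_factor(self) -> int:
        # Factor -> ( Expression ) | Number
        if self.match('('):
            value = self.parse_expression()   # indirect recursion
            self.expect(')')
            return value
        return self.consume_number()

def evaluate(text: str) -> int:
    parser = Calculator(tokenize(text))
    value = parser.parse_expression()
    if parser.pos != len(parser.tokens):
        raise SyntaxError("unexpected trailing tokens")
    return value

print(evaluate("2 + 3 * (4 + 1)"))  # 17
```

Each grammar rule gets its own function, and operator precedence falls out of which function calls which: parse_term sits below parse_expression, so multiplication binds tighter than addition.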

Conclusion

In this article we went over a large number of recursion patterns. Since you can write almost any problem with recursion, there are still plenty of problems left to explore. Here is a mix of higher difficulty problems that you can try to finish with recursion:

Break the problem into a current state, and think about how you want to traverse to the next one. Then think about how this code should be written to support that traversal.