Background
Have you ever seen people using functional tools like reduce, product, or accumulate? They can be quite overwhelming, and if you aren't familiar with them they make the code look mystical. Most of the time you can get away with only understanding reduce and using list comprehensions, but it's good to know what's available. Python introduces these in its Functional Programming Modules.
You can view this page as a deeper dive into lists from DSA 1 - Syntax and Snippets, and since this is quite code-specific I'll be writing it in a learn x in y minutes-style tutorial.
Method
Basics
This is mostly a duplicate of content in DSA 1 - Syntax and Snippets. One thing I learned here is the difference between a deep copy and a shallow copy in Python. If you make a slice in Rust, that slice points directly at the original array's data. In Python, list data doesn't sit contiguously in RAM; a list holds references, so slicing it produces a new list whose slots are copies of those references, still pointing at the same underlying objects.
>>> from dataclasses import dataclass
>>> import copy
>>>
>>> @dataclass
... class Number:
...     num: int = 0
...
>>> original = [Number(x) for x in range(3)]
>>> shallow = original[:]
>>> deep = copy.deepcopy(original)
>>>
>>> shallow[2].num = -500
>>> print(f'original = {original}')
original = [Number(num=0), Number(num=1), Number(num=-500)]
>>> print(f'shallow = {shallow}')
shallow = [Number(num=0), Number(num=1), Number(num=-500)]
>>> print(f'deep = {deep}')
deep = [Number(num=0), Number(num=1), Number(num=2)]
>>>
>>> primitive = [1, 2, 3, 4, 5]
>>> shallow = primitive[:]
>>> shallow[0] = 500
>>> print(primitive)
[1, 2, 3, 4, 5]
>>> print(shallow)
[500, 2, 3, 4, 5]
The primitive list appearing to make a true copy is a bit strange though. It isn't really a special case: the slice is still a shallow copy, because a list of primitives is still a list of references. The difference is that shallow[0] = 500 rebinds that slot to a new object instead of mutating a shared one. If we use id() to find the pointer address of values in these lists we can try to understand what's going on. True object references end up scattered around, while small primitives sit near each other (CPython caches small integers). If you massively increase one number, it's going to get allocated at a new address. If you do a shallow copy, the copy shares the same pointers until you assign one of its slots a new value.
>>> from itertools import pairwise
>>> from dataclasses import dataclass
>>> import copy
>>> from math import factorial
>>>
>>> @dataclass
... class Number:
...     num: int = 0
...
>>> refs = [Number(i) for i in range(10)]
>>> prim = [0, 1, 2, 4, 5, 6, 7, 8, 9]
>>>
>>> refs_addresses = [hex(id(x)) for x in refs]
>>> print(refs_addresses)
['0x79a343336210', '0x79a3432ba600', '0x79a3432ba570', '0x79a3432ba660', '0x79a3432ba5a0', '0x79a3432ba690', '0x79a3432ba6f0', '0x79a3432bad50', '0x79a3432b8c20', '0x79a3432ba6c0']
>>>
>>> prim_addresses = [hex(id(x)) for x in prim]
>>> print(prim_addresses)
['0xb36088', '0xb360a8', '0xb360c8', '0xb36108', '0xb36128', '0xb36148', '0xb36168', '0xb36188', '0xb361a8']
>>>
>>> diffs1 = [id(x) - id(y) for x, y in pairwise(refs)]
>>> print(diffs1)
[506896, 144, -240, 192, -240, -96, -1632, 8496, -6816]
>>>
>>> diffs2 = [id(x) - id(y) for x, y in pairwise(prim)]
>>> print(diffs2)
[-32, -32, -64, -32, -32, -32, -32, -32]
>>>
>>> prim[0] = factorial(1000)
>>> diffs2 = [id(x) - id(y) for x, y in pairwise(prim)]
>>> print(diffs2)
[4914120, -32, -64, -32, -32, -32, -32, -32]
>>>
>>> cpy = prim[:]
>>> print([id(x) - id(y) for x, y in zip(prim, cpy)])
[0, 0, 0, 0, 0, 0, 0, 0, 0]
>>> cpy[1] = 5
>>> print([id(x) - id(y) for x, y in zip(prim, cpy)])
[0, -128, 0, 0, 0, 0, 0, 0, 0]
>>> cpy[1] = factorial(1000)
>>> print([id(x) - id(y) for x, y in zip(prim, cpy)])
[0, -4915304, 0, 0, 0, 0, 0, 0, 0]
>>>
That chunk above is a bit technical though; don't worry if you're confused, since there are things I haven't introduced yet.
# - - - - - - - - - - bare basics - - - - - - - - - -
# declare some list
items = [1, 2, 3, 4, 5]
# basic iteration
for item in items:
    print(item)
for i in range(len(items)):
    print(items[i])
for i, item in enumerate(items):
    print(i, item)
# this can also be nested
items = [(1, 2), (3, 4), (5, 6)]
for i, (x, y) in enumerate(items):
    print(i, x, y)
# - - - - - - - - - - slicing - - - - - - - - - -
items = [1, 2, 3, 4, 5]
# prints this
# [1, 2] [2, 3, 4, 5] [3, 4]
print(items[:2], items[1:], items[2:4])
# slicing creates shallow copies, so you can clone the list like
cloned = items[:]
reverse_items = items[::-1]
# - - - - - - - - - - comprehension - - - - - - - - - -
items = [1, 2, 3, 4, 5]
above_two = [item for item in items if item > 2]
squared = [item * item for item in items]
packed = [(item, i) for i, item in enumerate(items)]
if_else = ['even' if x % 2 == 0 else 'odd' for x in items]
items = [[1, 2], [3, 4], [5, 6]]
flat = [num for row in items for num in row]
flat_evens = [num for row in items for num in row if num % 2 == 0]
identity = [[1 if r == c else 0 for c in range(3)] for r in range(3)]
# walrus operator to avoid duplicate operations
# you can assign values inside of expressions with it
items = [' apple', 'banana', ' ', 'pear ']
clean = [c for s in items if len(c := s.strip()) > 0]
# ['apple', 'banana', 'pear']
# you'd have to do this if it didn't exist
clean = [s.strip() for s in items if len(s.strip()) > 0]
# - - - - - - - - - - boolean reductions - - - - - - - - - -
flags = [True, False, True]
print(any(flags)) # => True
print(all(flags)) # => False
# Very powerful with generator expressions
nums = [2, 4, 6, 8]
is_all_even = all(n % 2 == 0 for n in nums) # => True
# - - - - - - - - - - zip function - - - - - - - - - -
list1 = [1, 2, 3, 4, 5]
list2 = [5, 6, 7, 8, 9]
for a, b in zip(list1, list2):
    print(a, b)  # 1 5\n 2 6\n 3 7\n ...
# - - - - - - - - - - Extended Iterable Unpacking (*) - - - - - - - - - -
# the issue with these is that they're O(n) because they create new lists
# in real FP languages, destructuring wouldn't copy anything
# so you can repeatedly grab the head and onion-peel the list to
# iterate through it (see the small recursive sketch after this block)...
items = [1, 2, 3, 4, 5]
# 1. Head and Tail
head, *tail = items
print(head) # => 1
print(tail) # => [2, 3, 4, 5]
# 2. last element
*rest, last = items
print(rest) # => [1, 2, 3, 4]
print(last) # => 5
# 3. Grab the ends, pack the middle
first, *middle, last = items
print(middle) # => [2, 3, 4]
# 4. Deep unpacking (Nested)
# If you have a list inside a list, you can match the structure
data = [1, [2, 3, 4], 5]
start, (inner_head, *inner_rest), end = data
print(inner_head) # => 2
print(inner_rest) # => [3, 4]
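To make the onion-peel idea from the comments above concrete, here's a minimal sketch of my own (not from the original notes): a recursive sum that repeatedly splits the list into head and tail. Because the tail is a fresh list on every call, this is O(n²) overall, which is exactly why you wouldn't actually iterate this way in Python.
def recursive_sum(items):
    # base case: empty list
    if not items:
        return 0
    head, *tail = items  # copies the tail, O(n) per call
    return head + recursive_sum(tail)

print(recursive_sum([1, 2, 3, 4, 5]))  # => 15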
There's some more basics in DSA 1 - Syntax and Snippets.
Itertools - Infinite iterators
Itertools splits itself into distinct sections, and the first is the infinite iterators. A lot of Haskell functions work like this by default because the language is lazy, but in Python you have to be quite explicit about it.
from itertools import *
# count(start, step)
counter = count(10, 2)
print(next(counter)) # => 10
print(next(counter)) # => 12
# cycle(iterable)
items = [1, 2]
cycler = cycle(items)
print(next(cycler)) # => 1
print(next(cycler)) # => 2
print(next(cycler)) # => 1
# repeat(elem, n)
repeater = repeat('Hello', 3)
print(list(repeater)) # ['Hello', 'Hello', 'Hello']
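Infinite iterators only become useful once something finite bounds them. A small sketch of my own (not from the original): zip stops at its shortest input, so pairing count() with a finite list works like enumerate, and islice (covered further down) can cut a cycle down to size.
from itertools import count, cycle, islice

names = ['a', 'b', 'c']
# zip stops when names runs out, so the infinite count() is safe here
print(list(zip(count(1), names)))     # => [(1, 'a'), (2, 'b'), (3, 'c')]
# islice bounds the otherwise endless cycle
print(list(islice(cycle(names), 5)))  # => ['a', 'b', 'c', 'a', 'b']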
Itertools - Combinatoric iterators
from itertools import *
# product is the same as multi-dimensional (nested) for loops
list1, list2, list3 = [1, 2], ['a', 'b'], [True, False]
prod = product(list1, list2)
prod = product(list1, list2, list3)
# iterating over the three-list product ...
for x, y, z in prod:
    pass
# ... is the same as writing
for x in list1:
    for y in list2:
        for z in list3:
            pass
# permutations(iterable, r): possible orderings, no repeated elements
perms = permutations([1, 2, 3], 2)
print(list(perms)) # => [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]
# combinations(iterable, r): Sorted order, no repeated elements
combs = combinations([1, 2, 3, 4], 2)
print(list(combs)) # => [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
# combinations_with_replacement(iterable, r): Elements can repeat
combs_w_r = combinations_with_replacement([1, 2], 2)
print(list(combs_w_r)) # => [(1, 1), (1, 2), (2, 2)]
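One place product shows up constantly, at least in my own usage (this example is mine, not from the original notes): replacing a double for loop over grid indices.
from itertools import product

grid = [[1, 2, 3],
        [4, 5, 6]]
# product(range(rows), range(cols)) stands in for two nested index loops
coords = product(range(len(grid)), range(len(grid[0])))
print([grid[r][c] for r, c in coords])  # => [1, 2, 3, 4, 5, 6]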
Itertools - Terminating Iterators
from itertools import *
import operator
# ==========================================
# 1. Aggregation & Slicing
# ==========================================
# accumulate(iterable, func=operator.add): Running totals (or other reduction)
# Unlike reduce(), it yields every intermediate step.
nums = [1, 2, 3, 4]
running_sum = accumulate(nums)
print(list(running_sum)) # => [1, 3, 6, 10]
running_max = accumulate([1, 5, 2, 8, 3], func=max)
print(list(running_max)) # => [1, 5, 5, 8, 8]
# batched(iterable, n): Chunks data into tuples of length n (Python 3.12+)
# Handles the last chunk gracefully even if it's shorter than n.
data = 'ABCDEFG'
batches = batched(data, 3)
print(list(batches)) # => [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
# islice(iterable, start, stop, step): Slicing for iterators (No data copying!)
# islice(count(), 0, 10, 2) is efficient; data[:10:2] would choke on infinite lists.
sliced = islice(count(), 0, 10, 2)
print(list(sliced)) # => [0, 2, 4, 6, 8]
# pairwise(iterable): Sliding window of size 2 (Python 3.10+)
# 'ABC' -> (A, B), (B, C)
# Great for calculating deltas: [y - x for x, y in pairwise(data)]
print(list(pairwise('ABC'))) # => [('A', 'B'), ('B', 'C')]
# ==========================================
# 2. Chaining & Splitting
# ==========================================
# chain(*iterables): Links multiple sequences together
chained = chain([1, 2], [3, 4])
print(list(chained)) # => [1, 2, 3, 4]
# chain.from_iterable(iterable): Flattens one level of nesting (Fast)
# Equivalent to: (x for sublist in data for x in sublist)
matrix = [[1, 2], [3, 4], [5, 6]]
flattened = chain.from_iterable(matrix)
print(list(flattened)) # => [1, 2, 3, 4, 5, 6]
# tee(iterable, n=2): Splits one iterator into n independent iterators
# CAUTION: Once you tee, DO NOT touch the original iterator again.
data = [1, 2, 3]
iter1, iter2 = tee(data, 2)
print(list(iter1)) # => [1, 2, 3]
print(list(iter2)) # => [1, 2, 3]
# zip_longest(*iterables, fillvalue=None): Zip that doesn't stop at the shortest
zipped = zip_longest('AB', '123', fillvalue='-')
print(list(zipped)) # => [('A', '1'), ('B', '2'), ('-', '3')]
# ==========================================
# 3. Filtering
# ==========================================
# compress(data, selectors): Filters data using a boolean mask
# Very useful when paired with numpy-style boolean arrays
values = ['A', 'B', 'C', 'D']
mask = [1, 0, 1, 0] # Keep A and C
print(list(compress(values, mask))) # => ['A', 'C']
# filterfalse(pred, seq): The opposite of filter(). Keeps items where pred is False.
# "Keep the odds" (because x % 2 == 0 is False for odds)
odds = filterfalse(lambda x: x % 2 == 0, [1, 2, 3, 4])
print(list(odds)) # => [1, 3]
# dropwhile(pred, seq): Drops items *until* predicate becomes False, then takes REST
# It only checks the start of the list.
dropped = dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1])
print(list(dropped)) # => [6, 4, 1] (Stops dropping once it hits 6)
# takewhile(pred, seq): Takes items *until* predicate becomes False, then STOPS
taken = takewhile(lambda x: x < 5, [1, 4, 6, 4, 1])
print(list(taken)) # => [1, 4] (Stops taking once it hits 6)
# ==========================================
# 4. Transformation
# ==========================================
# starmap(func, seq): Applies func(*args) to items.
# "starmap" = "star unpacking map" -> func(*item)
values = [(2, 5), (3, 2)] # 2^5, 3^2
print(list(starmap(pow, values))) # => [32, 9]
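The delta trick mentioned in the pairwise comment above, spelled out as a tiny sketch of my own:
from itertools import pairwise

prices = [100, 103, 101, 108]
deltas = [b - a for a, b in pairwise(prices)]
print(deltas)  # => [3, -2, 7]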
Itertools - Group by
This one is useful, but confusing/hard to use.
from itertools import groupby

# groupby(iterable, key=None): Groups consecutive elements with the same key.
# IMPORTANT: The input iterable MUST be sorted by the key first!
data = [{'role': 'admin', 'name': 'Alice'},
        {'role': 'user', 'name': 'Bob'},
        {'role': 'admin', 'name': 'Charlie'}]
# 1. Sort first
data.sort(key=lambda x: x['role'])
# 2. Group
# groups are iterators, so we usually cast them to lists immediately
for role, group in groupby(data, key=lambda x: x['role']):
    print(role, list(group))
# => admin [{'role': 'admin', 'name': 'Alice'}, {'role': 'admin', 'name': 'Charlie'}]
# => user [{'role': 'user', 'name': 'Bob'}]
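Since groupby is the confusing one, here's a second, self-contained sketch of my own: run-length encoding, which is probably its most common everyday use.
from itertools import groupby

# consecutive identical characters collapse into (char, count) pairs
text = 'aaabccdd'
rle = [(ch, len(list(grp))) for ch, grp in groupby(text)]
print(rle)  # => [('a', 3), ('b', 1), ('c', 2), ('d', 2)]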
Functools - Caching
Using this usually feels like cheating or a weird shorthand.
import functools
import time

# lru_cache(maxsize=128): Memoization decorator (Least Recently Used cache)
@functools.lru_cache(maxsize=None)
def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

start = time.time()
print(fib(30))              # => 832040 (computes instantly thanks to caching)
print(time.time() - start)  # => a tiny fraction of a second
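As a side note from me (not in the original): since Python 3.9 there's also functools.cache, which is just lru_cache(maxsize=None) with a shorter name.
import functools

@functools.cache
def fib(n):
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(100))  # => 354224848179261915075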
Functools - Partial functions
Probably mostly pointless.
# -- Partial Functions --
# partial(func, *args, **keywords): Freeze some arguments of a function
import functools

def power(base, exponent):
    return base ** exponent
# Create a new function that always squares its input
square = functools.partial(power, exponent=2)
print(square(10)) # => 100
# Create a new function that always has base 2
base_two = functools.partial(power, base=2)
print(base_two(5)) # => 32
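One place partial does earn its keep, in my experience (this example is mine, and as noted above a lambda would work just as well): pre-filling keyword arguments of a function you're about to pass somewhere else.
import functools

# pre-fill int()'s base to get a reusable hex parser
parse_hex = functools.partial(int, base=16)
print(list(map(parse_hex, ['ff', '10', 'a'])))  # => [255, 16, 10]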
Functools - Functional tools
reduce() is really the core function from here that's important.
import functools

# reduce(func, iterable[, initializer]): Apply function cumulatively
# (((1+2)+3)+4)
total = functools.reduce(lambda x, y: x + y, [1, 2, 3, 4])
print(total) # => 10
# reduce with initializer
total_init = functools.reduce(lambda x, y: x + y, [1, 2, 3, 4], 10)
print(total_init) # => 20
# cmp_to_key(func): Convert old-style comparison function to key function
# Comparison func returns -1, 0, or 1
def compare_len(s1, s2):
    return len(s1) - len(s2)
words = ['apple', 'banana', 'pear', 'grape']
# Sort by length using a comparison function converted to a key
sorted_words = sorted(words, key=functools.cmp_to_key(compare_len))
print(sorted_words) # => ['pear', 'apple', 'grape', 'banana']
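A small companion sketch of my own: the operator module provides named versions of +, *, and friends, which read better than lambdas inside reduce (and starmap above).
import functools
import operator

print(functools.reduce(operator.mul, [1, 2, 3, 4]))       # => 24
print(functools.reduce(operator.add, [[1], [2, 3]], []))  # => [1, 2, 3] (works, but chain is faster)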
Match statements
Python 3.9 has now reached end of life, and Python 3.10 introduced match statements, so they're safe to use in (almost) all contexts!
import random
def run(command):
    val = random.randint(1, 10)
    match command:
        case 'start':
            print('starting...')
        case 'run':
            print('running...')
        case 'stop':
            print('stopping...')
        case ['move', direction]:
            print(f'going {direction}...')
        case 'up' | 'down' | 'left' | 'right':
            print(f'Turning {command}')
        case 'lucky' if val == 1:
            print('You got lucky!')
        case 'lucky':
            print('try next time...')
        case _:
            print('unknown command!')
run('start')
run(['move', 'south'])
run('invalid command')
data = [1, 'one', True]
for item in data:
    match item:
        case int():
            print(f"Integer: {item}")
        case str():
            print(f"String: {item}")
        case _:
print(f"Other type: {item}")Where to practice these
Where to practice these
I find that Advent of Code problems are really good for these. Most of the pain in Advent of Code is the parsing, or in other words iterating through 2D grids in interesting ways. Otherwise, start integrating them into regular problem solving!
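As a flavour of what that parsing usually looks like, here's a tiny sketch of my own (the input string is made up): turning a block of text into a {(row, col): char} dictionary with a nested comprehension and enumerate.
raw = '..#\n#..\n...'
grid = {(r, c): ch
        for r, line in enumerate(raw.splitlines())
        for c, ch in enumerate(line)}
print(grid[(0, 2)])  # => #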
| Category | Function / Tool | Quick Description / Usage |
|---|---|---|
| Infinite | count(start, step) | 10, 12, 14... (Infinite counter) |
| | cycle(iter) | A, B, A, B... (Repeats sequence forever) |
| | repeat(elem, n) | X, X, X... (Repeats element n times) |
| Combinatorics | product(a, b) | Cartesian product (Nested loops replacement) |
| | permutations(p, r) | All orderings, no repeats |
| | combinations(p, r) | Sorted order, no repeats |
| Flow Control | chain(a, b) | Joins iterables: [1,2], [3,4] -> 1, 2, 3, 4 |
| | chain.from_iterable(m) | Flattens one level: [[1], [2]] -> 1, 2 |
| | zip_longest(a, b) | Zips to longest length (fills missing with None) |
| | islice(iter, start, stop) | Slicing for iterators (no copy) |
| | batched(iter, n) | Chunk into tuples: ABC... -> (A,B), (C...) |
| | pairwise(iter) | Sliding window: ABC... -> (A,B), (B,C) |
| | tee(iter, n) | Splits one iterator into n independent iterators |
| Filtering | compress(data, bools) | Filter data using a boolean mask list |
| | filterfalse(pred, iter) | Keep items where predicate is False |
| | dropwhile(pred, iter) | Drop items until predicate returns False |
| | takewhile(pred, iter) | Take items while predicate returns True |
| Reduction | reduce(func, iter) | Fold list into single value: ((a+b)+c) |
| | accumulate(iter) | Running totals: [1, 2, 3] -> [1, 3, 6] |
| Grouping | groupby(iter, key) | Group consecutive keys (Must sort first!) |
| Functional | starmap(func, iter) | Unpack args: pow, [(2,5)] -> pow(2,5) |
| | @lru_cache | Decorator to auto-cache function results |
| | partial(func, arg) | Pre-fill arguments: f(a, b) -> g(b) |
| Syntax | *rest | Unpacking: head, *tail = items |
| | := | Walrus operator: if (n := len(x)) > 5: |
| | match / case | Structural Pattern Matching (Switch statement) |
zyros