In 3 - My Interview Cookbook I mentioned a handful of “low level design” questions. To be fair, the type of low level design question I’m referring to is a bit different from the “parking lot” designs in hello interview’s LLD section or the questions you’ll find elsewhere online.
It’s worth reading the LLD section I wrote in 3 - My Interview Cookbook, because that gives the background on how these interviews work, in my experience. This page is dedicated to a set of questions you can practice on. At some point in the future, I might launch something like zyros.dev/learn (might not be a real URL, sorry) where I host these in an interactive editor in a friendlier way. That said, these interviews lean heavily toward testing your ability to communicate.
A number of people I’ve met don’t think these are worth practicing, and that you’re just ngmi if you can’t do them cold. I don’t really believe that. Like anything, it’s a skill you can build.
Disclaimer
Basically all of these questions are generated by an LLM. Yep. Sorry. I can explain how I did this, though. I started from a set of questions I had in mind (vague ideas, like rate limiters), then described in detail what I wanted. I then had the LLM generate 80 questions and write 1-3 solutions plus alternative solutions for each one. Finally, I went through that bank and hand-picked the questions I liked.
So, while it is heavily LLM-driven, I also human-reviewed it enough that it isn’t slop. It’s quite hard to find high-quality question sets for this. Maybe at some point someone will consolidate these into patterns or something, but I dunno.
Questions
Question 1 - Skip List
Difficulty: 9 / 10
Approximate lines of code: 100 LoC
Tags: data-structures
Description
A skip list is a probabilistic data structure that provides O(log n) expected time for search, insert, and delete operations, similar to a balanced BST, but using randomization instead of complex rebalancing logic. It consists of multiple levels of sorted linked lists stacked on top of each other. The bottom level (level 0) contains all elements. Each higher level contains a random subset of the elements below it, roughly half as many.
The key insight is the search pattern: start at the highest level and move right until you would overshoot, then drop down a level and repeat. This “express lane” approach lets you skip over many elements. The randomized level assignment (geometric distribution with p=0.5) ensures that on average, level k has half as many nodes as level k-1, giving O(log n) expected search time.
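To make the level assignment concrete, here is a minimal sketch of the coin-flip level generator (the same idea the solutions below use; the parameter names here are just illustrative):

import random

def random_level(p: float = 0.5, max_level: int = 16) -> int:
    """Flip a biased coin; every heads promotes the new node one level."""
    level = 0
    while random.random() < p and level < max_level - 1:
        level += 1
    return level

# With p=0.5, roughly 1/2 of nodes reach level 1, 1/4 reach level 2, 1/8 reach
# level 3, and so on, which keeps the expected height of the list at O(log n).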
Part A: Structure and Search
Problem: Part A
Build the skip list structure with a sentinel header node and implement search(key). The header has forward pointers at all levels. To search, start at the highest level and traverse right while the next node’s key is less than the target, then drop down.
sl = SkipList()
sl.insert(3)
sl.insert(6)
sl.insert(7)
sl.insert(9)
sl.insert(12)

# Possible structure (levels are random):
# Level 2: HEAD -----------------> 6 -----------------> 12 --> None
# Level 1: HEAD ------> 3 ------> 6 ------> 9 ------> 12 --> None
# Level 0: HEAD -> 3 -> 6 -> 7 -> 9 -> 12 --> None

# Search for 9:
# Level 2: HEAD -> 6 (6 < 9, advance) -> 12 (12 > 9, drop down)
# Level 1: 6 -> 9 (found!)
# Returns True
sl.search(9)  # True
sl.search(8)  # False - would stop between 7 and 9 at level 0
Part B: Insert
Problem: Part B
To insert a key, first determine its random level using a geometric distribution (flip coins until tails). As you search for the insertion point, track which nodes at each level will need their forward pointers updated. After finding the position, splice in the new node at all levels up to its assigned level.
sl = SkipList(max_level=4, p=0.5)
# Before inserting 8:
# Level 1: HEAD -> 6 ---------> 12 --> None
# Level 0: HEAD -> 6 -> 9 -> 12 --> None

sl.insert(8)  # Randomly gets level 1

# The "update" array tracks predecessors at each level:
# update[1] = HEAD (at level 1, HEAD.forward[1] = 6, but 6 < 8, so we advance...
#             6.forward[1] = 12, but 12 > 8, so update[1] = 6)
# update[0] = 6 (at level 0, we stop at 6 since 9 > 8)

# After insert:
# Level 1: HEAD -> 6 -> 8 ---------> 12 --> None
# Level 0: HEAD -> 6 -> 8 -> 9 -> 12 --> None

# New node's forward pointers:
# new_node.forward[0] = 9  (was update[0].forward[0])
# new_node.forward[1] = 12 (was update[1].forward[1])
Part C: Delete
Problem: Part C
Find the node using the same traversal pattern, tracking update nodes at each level. Remove the node by updating each predecessor’s forward pointer to skip over it. If the deleted node was at the highest level and that level is now empty, decrease the skip list’s level.
sl = SkipList()
# Level 2: HEAD -> 6 -----------------> 19 --> None
# Level 1: HEAD -> 6 ------> 12 ------> 19 --> None
# Level 0: HEAD -> 6 -> 9 -> 12 -> 17 -> 19 --> None

sl.delete(12)
# update array: update[0] = node(9), update[1] = node(6)
# Unlink at each level where update[i].forward[i] == target

# After delete:
# Level 2: HEAD -> 6 -----------------> 19 --> None
# Level 1: HEAD -> 6 -----------------> 19 --> None
# Level 0: HEAD -> 6 -> 9 -> 17 ------> 19 --> None

sl.delete(19)  # If level 2 becomes empty (HEAD.forward[2] == None),
               # decrease sl.level to 1
Interview comments
Edge cases to probe:
What happens when you insert a duplicate key? (Either reject or update value)
What if the new node’s random level exceeds the current max level? (Extend update array, increase list level)
What probability do you use for level generation? (p=0.5 is typical, gives log base 2 behavior)
How do you shrink the level after deleting the only high-level node?
Common mistakes:
Off-by-one errors in level iteration: range(self.level, -1, -1) vs range(self.level - 1, -1, -1)
Forgetting to update the skip list’s level when highest levels become empty
Using uniform distribution instead of geometric for level generation (must use coin flips)
Not initializing update array correctly for new levels above current max
Infinite loop when search doesn’t properly handle the sentinel’s key
Code solutions
Solution 1 is a classic implementation with a sentinel node and geometric level distribution. Solution 2 extends this with generic key-value storage and update semantics. Solution 3 adds iterator support, __len__, __contains__, and range queries. The key difference is the feature set: basic operations vs. key-value mapping vs. full collection interface with range queries.
Solution 1: Classic implementation with sentinel node
Classic implementation with sentinel node. Uses geometric distribution for level generation. Clean separation of search traversal pattern used by all operations.
"""Skip List Implementation - Classic Approach with Sentinel Node."""from dataclasses import dataclass, fieldfrom typing import Optionalimport random@dataclassclass SkipNode: """Node in a skip list with multiple forward pointers.""" key: int forward: list["SkipNode"] = field(default_factory=list)@dataclassclass SkipList: """Probabilistic sorted data structure with O(log n) expected operations.""" max_level: int = 16 p: float = 0.5 level: int = 0 header: SkipNode = field(default_factory=lambda: SkipNode(key=-1)) def __post_init__(self) -> None: self.header.forward = [None] * self.max_level def _random_level(self) -> int: """Generate random level with geometric distribution.""" lvl = 0 while random.random() < self.p and lvl < self.max_level - 1: lvl += 1 return lvl def search(self, key: int) -> bool: """Return True if key exists in the skip list.""" current = self.header for i in range(self.level, -1, -1): while current.forward[i] and current.forward[i].key < key: current = current.forward[i] current = current.forward[0] return current is not None and current.key == key def insert(self, key: int) -> None: """Insert a key into the skip list.""" update = [None] * self.max_level current = self.header for i in range(self.level, -1, -1): while current.forward[i] and current.forward[i].key < key: current = current.forward[i] update[i] = current current = current.forward[0] if current and current.key == key: return # Key already exists new_level = self._random_level() if new_level > self.level: for i in range(self.level + 1, new_level + 1): update[i] = self.header self.level = new_level new_node = SkipNode(key=key, forward=[None] * (new_level + 1)) for i in range(new_level + 1): new_node.forward[i] = update[i].forward[i] update[i].forward[i] = new_node def delete(self, key: int) -> bool: """Delete a key from the skip list. Returns True if deleted.""" update = [None] * self.max_level current = self.header for i in range(self.level, -1, -1): while current.forward[i] and current.forward[i].key < key: current = current.forward[i] update[i] = current current = current.forward[0] if current is None or current.key != key: return False for i in range(self.level + 1): if update[i].forward[i] != current: break update[i].forward[i] = current.forward[i] while self.level > 0 and self.header.forward[self.level] is None: self.level -= 1 return True
Solution 2: Generic implementation with key-value storage
Generic implementation with key-value storage. Supports updating existing keys and uses type parameters for flexibility.
"""Skip List Implementation - Generic with Value Storage."""from dataclasses import dataclass, fieldfrom typing import Generic, Optional, TypeVarimport randomK = TypeVar("K", bound=int)V = TypeVar("V")@dataclassclass Node(Generic[K, V]): """Skip list node storing key-value pair.""" key: K value: V forward: list[Optional["Node[K, V]"]] = field(default_factory=list) def __post_init__(self) -> None: if not self.forward: self.forward = [None]class SkipList(Generic[K, V]): """Skip list with key-value storage and configurable parameters.""" def __init__(self, max_level: int = 16, probability: float = 0.5) -> None: self.max_level = max_level self.probability = probability self.level = 0 self.head: Node[K, V] = Node(key=None, value=None) # type: ignore self.head.forward = [None] * max_level def _random_level(self) -> int: level = 0 while random.random() < self.probability and level < self.max_level - 1: level += 1 return level def _find_update_path(self, key: K) -> tuple[list[Node[K, V]], Optional[Node[K, V]]]: """Find path of nodes to update and the target node.""" update: list[Optional[Node[K, V]]] = [None] * self.max_level current = self.head for i in range(self.level, -1, -1): while current.forward[i] and current.forward[i].key < key: current = current.forward[i] update[i] = current return update, current.forward[0] def get(self, key: K) -> Optional[V]: """Retrieve value for key, or None if not found.""" _, target = self._find_update_path(key) if target and target.key == key: return target.value return None def insert(self, key: K, value: V) -> None: """Insert or update a key-value pair.""" update, target = self._find_update_path(key) if target and target.key == key: target.value = value # Update existing return new_level = self._random_level() if new_level > self.level: for i in range(self.level + 1, new_level + 1): update[i] = self.head self.level = new_level new_node: Node[K, V] = Node(key=key, value=value) new_node.forward = [None] * (new_level + 1) for i in range(new_level + 1): new_node.forward[i] = update[i].forward[i] update[i].forward[i] = new_node def delete(self, key: K) -> bool: """Remove key from skip list. Returns True if removed.""" update, target = self._find_update_path(key) if target is None or target.key != key: return False for i in range(self.level + 1): if update[i].forward[i] != target: break update[i].forward[i] = target.forward[i] while self.level > 0 and self.head.forward[self.level] is None: self.level -= 1 return True
Solution 3: Iterator support and range queries
Extended implementation with iterator support, __len__, __contains__, and range queries. Tracks size explicitly for O(1) length.
"""Skip List Implementation - Iterator Support and Range Queries."""from dataclasses import dataclass, fieldfrom typing import Iterator, Optionalimport random@dataclassclass SkipNode: """Node with key and multi-level forward pointers.""" key: int forward: list[Optional["SkipNode"]] = field(default_factory=list)class SkipList: """Skip list with iteration and range query support.""" MAX_LEVEL = 16 P = 0.5 def __init__(self) -> None: self.level = 0 self.head = SkipNode(key=float("-inf")) # type: ignore self.head.forward = [None] * self.MAX_LEVEL self._size = 0 def __len__(self) -> int: return self._size def __iter__(self) -> Iterator[int]: """Iterate through all keys in sorted order.""" node = self.head.forward[0] while node: yield node.key node = node.forward[0] def __contains__(self, key: int) -> bool: return self.search(key) def _random_level(self) -> int: lvl = 0 while random.random() < self.P and lvl < self.MAX_LEVEL - 1: lvl += 1 return lvl def search(self, key: int) -> bool: """Check if key exists.""" node = self.head for i in range(self.level, -1, -1): while node.forward[i] and node.forward[i].key < key: node = node.forward[i] node = node.forward[0] return node is not None and node.key == key def insert(self, key: int) -> bool: """Insert key. Returns True if inserted, False if already exists.""" update = [self.head] * self.MAX_LEVEL node = self.head for i in range(self.level, -1, -1): while node.forward[i] and node.forward[i].key < key: node = node.forward[i] update[i] = node if node.forward[0] and node.forward[0].key == key: return False new_lvl = self._random_level() if new_lvl > self.level: self.level = new_lvl new_node = SkipNode(key=key, forward=[None] * (new_lvl + 1)) for i in range(new_lvl + 1): new_node.forward[i] = update[i].forward[i] update[i].forward[i] = new_node self._size += 1 return True def delete(self, key: int) -> bool: """Delete key. Returns True if deleted.""" update = [self.head] * self.MAX_LEVEL node = self.head for i in range(self.level, -1, -1): while node.forward[i] and node.forward[i].key < key: node = node.forward[i] update[i] = node target = node.forward[0] if target is None or target.key != key: return False for i in range(len(target.forward)): update[i].forward[i] = target.forward[i] while self.level > 0 and self.head.forward[self.level] is None: self.level -= 1 self._size -= 1 return True def range_query(self, low: int, high: int) -> list[int]: """Return all keys in [low, high] inclusive.""" result = [] node = self.head for i in range(self.level, -1, -1): while node.forward[i] and node.forward[i].key < low: node = node.forward[i] node = node.forward[0] while node and node.key <= high: result.append(node.key) node = node.forward[0] return result
Question 2 - Request Coalescer
Description
A request coalescer deduplicates concurrent requests for the same resource. When multiple callers request the same key simultaneously, only one actual fetch occurs - all other callers wait for and share the result. This pattern is common in caching layers, API gateways, and database connection pools to prevent thundering herd problems, where many identical requests overwhelm a backend service.
The core data structures are: (1) an in-flight map tracking keys currently being fetched, mapping each key to a Future/Event that waiters can await, and (2) a cache storing results with TTL. The critical insight is lock scope: you hold the lock only to check/update the in-flight map, never during the actual fetch.
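As a minimal sketch of that lock discipline (module-level names here are illustrative, not the API of the solutions below):

import asyncio

_in_flight: dict[str, asyncio.Future] = {}
_lock = asyncio.Lock()

async def get(key: str, fetch) -> object:
    async with _lock:  # lock held only while touching the in-flight map
        fut = _in_flight.get(key)
        if fut is None:
            fut = asyncio.get_running_loop().create_future()
            _in_flight[key] = fut
            asyncio.create_task(_resolve(key, fetch, fut))
    return await fut  # waiting happens outside the lock

async def _resolve(key: str, fetch, fut: asyncio.Future) -> None:
    try:
        fut.set_result(await fetch(key))  # slow I/O runs with no lock held
    except Exception as exc:
        fut.set_exception(exc)  # every waiter sees the same error
    finally:
        async with _lock:
            _in_flight.pop(key, None)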
Part A: Request Coalescing
Problem: Part A
Implement a get(key) method that fetches a resource. If a request for that key is already in-flight, wait for it instead of starting a new fetch.
coalescer = RequestCoalescer(fetch_func=fetch_from_db)

# Three concurrent requests for same key
results = await asyncio.gather(
    coalescer.get("user:123"),
    coalescer.get("user:123"),
    coalescer.get("user:123"),
)
# fetch_from_db called exactly ONCE
# All three get the same result

# Internal state during fetch:
# _in_flight = {"user:123": <Future pending>}
# After completion:
# _in_flight = {}
Part B: TTL Cache
Problem: Part B
Add caching with time-to-live. After a successful fetch, cache the result. Subsequent requests within TTL return the cached value without checking in-flight or fetching.
Part C: Error Propagation
Problem: Part C
When a fetch fails, propagate the exception to ALL waiters, not just the initiator. Optionally cache errors (negative caching) to prevent repeated failing requests.
async def failing_fetch(key):
    raise ConnectionError("Database down")

coalescer = RequestCoalescer(fetch_func=failing_fetch)

# All three waiters receive the same exception
results = await asyncio.gather(
    coalescer.get("user:123"),
    coalescer.get("user:123"),
    coalescer.get("user:123"),
    return_exceptions=True
)
# results = [ConnectionError, ConnectionError, ConnectionError]
# Only ONE actual fetch attempt was made
Interview comments
Edge cases to probe:
What happens if the first request times out but others are still waiting?
How do you handle the race between cache expiry check and in-flight check?
Should errors be cached? For how long?
What if cache_ttl=0 - does coalescing still work?
Common mistakes:
Holding the lock during the actual fetch (blocks all other keys)
Not propagating exceptions to all waiters (only first caller sees error)
Race condition: cache expires between check and in-flight registration
Forgetting to remove key from in-flight map on error
Using time.time() instead of time.monotonic() (clock can jump)
Code solutions
Solution 1 uses asyncio with a lock and dictionary of Futures for coordination. Solution 2 replaces Futures with asyncio.Event plus a result holder dataclass. Solution 3 provides a thread-safe synchronous version using threading primitives and ThreadPoolExecutor. These vary in their coordination mechanism (Future vs Event vs threading Future) and async vs sync model. Core techniques: double-checked locking, Future/Event-based waiting, spawning background tasks for actual fetch work.
Solution 1: asyncio with locks and Futures
Uses asyncio with a lock and dictionary of Futures. Double-checks cache after acquiring lock to handle races. Creates a Future for waiters and spawns a task to do the actual fetch.
"""Request Coalescer - Solution 1: Basic asyncio with locks and cache.Uses a simple dictionary-based approach with asyncio primitives."""import asyncioimport timefrom dataclasses import dataclass, fieldfrom typing import Any, Awaitable, Callable, Dict, OptionalFetchFunc = Callable[[str], Awaitable[Any]]@dataclassclass CacheEntry: value: Any expires_at: float error: Optional[Exception] = None@dataclassclass RequestCoalescer: fetch_func: FetchFunc cache_ttl: float = 5.0 _cache: Dict[str, CacheEntry] = field(default_factory=dict) _in_flight: Dict[str, asyncio.Future] = field(default_factory=dict) _lock: asyncio.Lock = field(default_factory=asyncio.Lock) async def get(self, key: str) -> Any: # Check cache first (no lock needed for read) if key in self._cache: entry = self._cache[key] if time.monotonic() < entry.expires_at: if entry.error: raise entry.error return entry.value del self._cache[key] async with self._lock: # Double-check cache after acquiring lock if key in self._cache: entry = self._cache[key] if time.monotonic() < entry.expires_at: if entry.error: raise entry.error return entry.value del self._cache[key] # Join existing in-flight request if key in self._in_flight: future = self._in_flight[key] else: # Start new request future = asyncio.get_event_loop().create_future() self._in_flight[key] = future asyncio.create_task(self._do_fetch(key, future)) # Wait for result outside the lock return await future async def _do_fetch(self, key: str, future: asyncio.Future) -> None: try: result = await self.fetch_func(key) self._cache[key] = CacheEntry( value=result, expires_at=time.monotonic() + self.cache_ttl ) future.set_result(result) except Exception as e: self._cache[key] = CacheEntry( value=None, expires_at=time.monotonic() + self.cache_ttl, error=e ) future.set_exception(e) finally: async with self._lock: self._in_flight.pop(key, None)
Solution 2: asyncio.Event for coordination
Uses asyncio.Event instead of Future for coordination. The PendingRequest dataclass holds both the event and the result/error, with waiters checking the result after the event is set.
"""Request Coalescer - Solution 2: Using asyncio.Event for coordination.Demonstrates an alternative pattern using Events instead of Futures."""import asyncioimport timefrom dataclasses import dataclass, fieldfrom typing import Any, Awaitable, Callable, Dict, Optional, TupleFetchFunc = Callable[[str], Awaitable[Any]]@dataclassclass PendingRequest: event: asyncio.Event result: Any = None error: Optional[Exception] = None@dataclassclass CachedResult: value: Any error: Optional[Exception] expires_at: float def is_valid(self) -> bool: return time.monotonic() < self.expires_at def get_or_raise(self) -> Any: if self.error: raise self.error return self.value@dataclassclass RequestCoalescer: fetch_func: FetchFunc cache_ttl: float = 5.0 _cache: Dict[str, CachedResult] = field(default_factory=dict) _pending: Dict[str, PendingRequest] = field(default_factory=dict) _lock: asyncio.Lock = field(default_factory=asyncio.Lock) async def get(self, key: str) -> Any: # Fast path: check cache cached = self._cache.get(key) if cached and cached.is_valid(): return cached.get_or_raise() async with self._lock: # Re-check after lock cached = self._cache.get(key) if cached and cached.is_valid(): return cached.get_or_raise() # Clean stale cache if cached: del self._cache[key] # Join or create pending request if key in self._pending: pending = self._pending[key] else: pending = PendingRequest(event=asyncio.Event()) self._pending[key] = pending asyncio.create_task(self._execute_fetch(key, pending)) # Wait outside lock await pending.event.wait() if pending.error: raise pending.error return pending.result async def _execute_fetch(self, key: str, pending: PendingRequest) -> None: try: result = await self.fetch_func(key) pending.result = result self._cache[key] = CachedResult( value=result, error=None, expires_at=time.monotonic() + self.cache_ttl ) except Exception as e: pending.error = e self._cache[key] = CachedResult( value=None, error=e, expires_at=time.monotonic() + self.cache_ttl ) finally: pending.event.set() async with self._lock: self._pending.pop(key, None)
Solution 3: Thread-safe synchronous version
Thread-safe synchronous version using threading primitives and ThreadPoolExecutor. Suitable for non-async codebases. Uses concurrent.futures.Future for coordination between threads.
"""Request Coalescer - Solution 3: Thread-safe version using threading primitives.Suitable for synchronous code or mixed async/sync environments."""import threadingimport timefrom concurrent.futures import Future, ThreadPoolExecutorfrom dataclasses import dataclass, fieldfrom typing import Any, Callable, Dict, OptionalFetchFunc = Callable[[str], Any]@dataclassclass CacheEntry: value: Any error: Optional[Exception] expires_at: float @property def is_expired(self) -> bool: return time.monotonic() >= self.expires_at@dataclassclass RequestCoalescer: fetch_func: FetchFunc cache_ttl: float = 5.0 max_workers: int = 4 _cache: Dict[str, CacheEntry] = field(default_factory=dict) _in_flight: Dict[str, Future] = field(default_factory=dict) _lock: threading.Lock = field(default_factory=threading.Lock) _executor: ThreadPoolExecutor = field(init=False) def __post_init__(self): self._executor = ThreadPoolExecutor(max_workers=self.max_workers) def get(self, key: str) -> Any: # Check cache (acquire lock for thread safety) with self._lock: if key in self._cache: entry = self._cache[key] if not entry.is_expired: if entry.error: raise entry.error return entry.value del self._cache[key] # Join existing request or start new one if key in self._in_flight: future = self._in_flight[key] else: future = Future() self._in_flight[key] = future self._executor.submit(self._do_fetch, key, future) # Wait for result (outside lock to allow concurrency) return future.result() def _do_fetch(self, key: str, future: Future) -> None: try: result = self.fetch_func(key) with self._lock: self._cache[key] = CacheEntry( value=result, error=None, expires_at=time.monotonic() + self.cache_ttl ) future.set_result(result) except Exception as e: with self._lock: self._cache[key] = CacheEntry( value=None, error=e, expires_at=time.monotonic() + self.cache_ttl ) future.set_exception(e) finally: with self._lock: self._in_flight.pop(key, None) def shutdown(self): self._executor.shutdown(wait=True)
Question 3 - Lock-Free Queue
Difficulty: 10 / 10
Approximate lines of code: 80 LoC
Tags: concurrency, data-structures
Description
Concurrency Cheat Sheet
This problem involves concurrency. Here is a cheat sheet of relevant Python syntax. You may not need all of these; the list is neither exhaustive nor a hint about what the problem requires.
# threading.Thread
t = threading.Thread(target=func, args=(arg1, arg2))
t.start()
t.join()

# threading.Lock
lock = threading.Lock()
lock.acquire()
lock.release()
with lock:  # critical section

# threading.Condition
cond = threading.Condition()
with cond:
    cond.wait()        # release lock, wait for notify
    cond.notify()      # wake one waiter
    cond.notify_all()  # wake all waiters

# threading.Semaphore
sem = threading.Semaphore(value=3)
sem.acquire()
sem.release()

# threading.Event
event = threading.Event()
event.set()     # signal
event.clear()   # reset
event.wait()    # block until set
event.is_set()  # check without blocking

# queue.Queue (thread-safe)
from queue import Queue
q = Queue()
q.put(item)
item = q.get()         # blocks
item = q.get_nowait()  # raises Empty if none

# Atomic operations concept
# Python's GIL makes single bytecode ops atomic
# For true atomics, use: threading.Lock or ctypes/cffi
# CAS pattern: if current == expected: current = new
A lock-free queue allows multiple threads to enqueue and dequeue concurrently without using traditional locks (mutexes). Instead of blocking threads, it uses atomic compare-and-swap (CAS) operations: “if the value is still X, change it to Y; otherwise retry.” This enables progress guarantees - at least one thread makes progress even if others stall.
The classic implementation is the Michael-Scott queue: a linked list with head and tail pointers. A sentinel (dummy) node simplifies empty-queue handling. The key insight is that CAS operations can fail when another thread modifies the data, so operations run in retry loops. The algorithm also requires “helping” - if one thread sees the tail pointer lagging behind, it advances the tail before proceeding.
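Since Python has no hardware CAS instruction, the solutions below simulate one with a lock; roughly, the primitive and its retry loop look like this (names are illustrative, not part of any solution's API):

import threading

_cas_lock = threading.Lock()  # stands in for the hardware atomic

def compare_and_swap(obj: object, attr: str, expected: object, new: object) -> bool:
    """If obj.attr is still `expected`, set it to `new`; otherwise fail."""
    with _cas_lock:
        if getattr(obj, attr) is expected:
            setattr(obj, attr, new)
            return True
        return False

# Caller shape: keep retrying until our CAS wins.
# while True:
#     observed = queue.tail
#     ... build new node ...
#     if compare_and_swap(queue, "tail", observed, new_node):
#         break  # success; otherwise another thread moved tail, so re-read and retry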
Part A: Basic Queue
Problem: Part A
Implement a thread-safe queue with enqueue(value) and dequeue() using CAS. Use a linked list with a sentinel node (head and tail both point to it initially).
q = LockFreeQueue()
# Initial state: head -> [sentinel] <- tail
#                         next=None

q.enqueue(1)
# State: head -> [sentinel] -> [1] <- tail
#                 next=[1]     next=None

q.enqueue(2)
# State: head -> [sentinel] -> [1] -> [2] <- tail

val = q.dequeue()  # Returns 1
# State: head -> [1] -> [2] <- tail
# (old sentinel removed, [1] is new sentinel, value returned was [1]'s value)
Key operations:
Enqueue: CAS the tail node’s next pointer from None to new node, then CAS tail to point to new node
Dequeue: CAS head pointer from sentinel to sentinel.next, return the value from the new sentinel
Part B: The ABA Problem
Problem: Part B
A subtle bug: if pointer A changes to B then back to A, a CAS sees “still A” but the underlying data changed. Example: Thread 1 reads head=NodeA, gets preempted. Thread 2 dequeues NodeA, dequeues NodeB, re-enqueues a recycled NodeA. Thread 1 resumes, CAS succeeds (head is “still” NodeA), but the queue structure is corrupted.
Solve with tagged pointers (version stamps): pair each pointer with an incrementing counter. CAS checks both pointer AND stamp - even if pointer recycles, the stamp won’t match.
@dataclass
class StampedRef:
    ref: Node
    stamp: int  # Increments on every modification

# CAS now checks: (ref == expected_ref AND stamp == expected_stamp)
# Even if ref is recycled, stamp won't match the old value
Part C: Memory Reclamation
Problem: Part C
In a real system, you can’t free dequeued nodes immediately - another thread might still be reading them. Options:
Hazard pointers: Each thread publishes pointers it’s currently accessing; don’t free those
Epoch-based reclamation: Track which “epoch” each thread is in; only free nodes from old epochs when all threads have advanced
Reference counting: Atomic reference counts (but adds overhead to every operation)
# Hazard pointer sketch:
class ThreadLocal:
    hazard_pointer: Node = None  # "I'm currently looking at this node"

def dequeue():
    while True:
        head = self.head
        # Publish hazard pointer BEFORE dereferencing
        thread_local.hazard_pointer = head
        if head != self.head:
            continue  # head changed, retry
        # Now safe to read head.next
        ...
Interview comments
Edge cases to probe:
What happens if dequeue is called on an empty queue? (Return None, don’t block)
Why do you need a sentinel node? (Simplifies empty case - head and tail always point to valid nodes)
What if enqueue CAS succeeds but advancing tail fails? (Another thread must help advance it)
How do you know if the queue is empty? (head == tail AND head.next is None)
Common mistakes:
Forgetting the sentinel node (leads to null pointer edge cases)
Missing the “helping” mechanism in enqueue (if tail.next isn’t None, advance tail first)
Wrong CAS order: must CAS next pointer BEFORE advancing tail
Using raw pointers without version stamps (ABA vulnerability)
Not handling the case where head.next is null in dequeue (empty queue)
Code solutions
Solution 1 implements the classic Michael-Scott queue using a linked list with stamped references to solve the ABA problem. Solution 2 uses an array-based bounded circular buffer with sequence numbers for ABA prevention, offering better cache locality. Solution 3 is a wait-free SPSC (Single-Producer Single-Consumer) queue that eliminates CAS retry loops by restricting to one producer and one consumer. These vary in their concurrency model (MPMC vs SPSC) and data structure (linked list vs array). Core techniques: compare-and-swap (CAS), stamped/tagged pointers, sequence numbers, memory barriers.
Solution 1: Michael-Scott queue with stamped references
Michael-Scott queue with stamped references to solve ABA. Uses a lock to simulate atomic CAS operations (real implementations use hardware atomics). Includes the full “helping” mechanism where enqueue advances a lagging tail pointer.
"""Lock-Free Queue - Solution 1: Michael-Scott Queue with Tagged PointersUses stamped references to solve the ABA problem."""from __future__ import annotationsfrom dataclasses import dataclassfrom typing import Generic, TypeVar, Optionalimport threadingT = TypeVar("T")@dataclassclass StampedRef(Generic[T]): """Stamped reference to solve ABA problem.""" ref: Optional[Node[T]] stamp: int = 0@dataclassclass Node(Generic[T]): value: Optional[T] next: StampedRef[T] = None # type: ignore def __post_init__(self) -> None: if self.next is None: self.next = StampedRef(None, 0)class LockFreeQueue(Generic[T]): """Lock-free queue using CAS with stamped references.""" def __init__(self) -> None: sentinel = Node[T](None) self._head = StampedRef(sentinel, 0) self._tail = StampedRef(sentinel, 0) self._lock = threading.Lock() # Simulates CAS atomicity def _cas(self, target: StampedRef[T], expected_ref: Optional[Node[T]], expected_stamp: int, new_ref: Optional[Node[T]], new_stamp: int) -> bool: """Simulated compare-and-swap with stamp check.""" with self._lock: if target.ref is expected_ref and target.stamp == expected_stamp: target.ref = new_ref target.stamp = new_stamp return True return False def enqueue(self, value: T) -> None: """Add item to the back of the queue.""" new_node = Node(value) while True: tail = self._tail tail_node = tail.ref next_ref = tail_node.next if tail.ref is self._tail.ref and tail.stamp == self._tail.stamp: if next_ref.ref is None: if self._cas(tail_node.next, None, next_ref.stamp, new_node, next_ref.stamp + 1): self._cas(self._tail, tail_node, tail.stamp, new_node, tail.stamp + 1) return else: self._cas(self._tail, tail_node, tail.stamp, next_ref.ref, tail.stamp + 1) def dequeue(self) -> Optional[T]: """Remove and return item from front. Returns None if empty.""" while True: head = self._head tail = self._tail head_node = head.ref next_ref = head_node.next if head.ref is self._head.ref and head.stamp == self._head.stamp: if head_node is tail.ref: if next_ref.ref is None: return None self._cas(self._tail, tail.ref, tail.stamp, next_ref.ref, tail.stamp + 1) else: value = next_ref.ref.value if self._cas(self._head, head_node, head.stamp, next_ref.ref, head.stamp + 1): return value
Solution 2: Array-based bounded queue
Array-based bounded queue using a circular buffer with atomic indices. Each slot has a sequence number for ABA prevention. The sequence discipline: a slot is ready for enqueue when seq == tail, and ready for dequeue when seq == head + 1. More cache-friendly than linked lists.
"""Lock-Free Queue - Solution 2: Array-based Bounded Queue with CASFixed-size circular buffer using atomic indices."""from __future__ import annotationsfrom dataclasses import dataclass, fieldfrom typing import Generic, TypeVar, Optionalimport threadingT = TypeVar("T")@dataclassclass AtomicInt: """Simulated atomic integer with CAS operation.""" _value: int = 0 _lock: threading.Lock = field(default_factory=threading.Lock) def get(self) -> int: with self._lock: return self._value def cas(self, expected: int, new: int) -> bool: with self._lock: if self._value == expected: self._value = new return True return False@dataclassclass Slot(Generic[T]): """Slot with sequence number for ABA prevention.""" sequence: AtomicInt value: Optional[T] = Noneclass BoundedLockFreeQueue(Generic[T]): """Bounded lock-free queue using array with sequences.""" def __init__(self, capacity: int = 64) -> None: self._capacity = capacity self._mask = capacity - 1 assert capacity > 0 and (capacity & self._mask) == 0, "Must be power of 2" self._slots: list[Slot[T]] = [ Slot(AtomicInt(i)) for i in range(capacity) ] self._head = AtomicInt(0) self._tail = AtomicInt(0) def enqueue(self, value: T) -> bool: """Add item. Returns False if queue is full.""" while True: tail = self._tail.get() slot = self._slots[tail & self._mask] seq = slot.sequence.get() diff = seq - tail if diff == 0: if self._tail.cas(tail, tail + 1): slot.value = value slot.sequence.cas(seq, seq + 1) return True elif diff < 0: return False # Queue is full def dequeue(self) -> Optional[T]: """Remove and return item. Returns None if empty.""" while True: head = self._head.get() slot = self._slots[head & self._mask] seq = slot.sequence.get() diff = seq - (head + 1) if diff == 0: if self._head.cas(head, head + 1): value = slot.value slot.sequence.cas(seq, seq + self._capacity - 1) return value elif diff < 0: return None # Queue is empty
Solution 3: Wait-free SPSC queue
Wait-free Single-Producer Single-Consumer (SPSC) queue. When only one thread enqueues and one dequeues, the algorithm simplifies dramatically - no CAS loops needed, just atomic loads and stores with cached indices for performance. Demonstrates the power of restricting the concurrency model.
"""Lock-Free Queue - Solution 3: Wait-Free SPSC QueueSingle-Producer Single-Consumer queue - simpler but restricted use case.Demonstrates progression from MPMC to SPSC optimization."""from __future__ import annotationsfrom dataclasses import dataclass, fieldfrom typing import Generic, TypeVar, Optionalimport threadingT = TypeVar("T")@dataclassclass AtomicRef(Generic[T]): """Simulated atomic reference for memory ordering.""" _value: T _lock: threading.Lock = field(default_factory=threading.Lock) def load(self) -> T: with self._lock: return self._value def store(self, value: T) -> None: with self._lock: self._value = valueclass SPSCQueue(Generic[T]): """Wait-free single-producer single-consumer queue.""" def __init__(self, capacity: int = 64) -> None: assert capacity > 0 and (capacity & (capacity - 1)) == 0 self._capacity = capacity self._mask = capacity - 1 self._buffer: list[Optional[T]] = [None] * capacity self._head = AtomicRef(0) # Consumer reads this self._tail = AtomicRef(0) # Producer writes this self._cached_head = 0 # Producer's cached view self._cached_tail = 0 # Consumer's cached view def enqueue(self, value: T) -> bool: """Producer adds item. Returns False if full.""" tail = self._tail.load() next_tail = (tail + 1) & self._mask if next_tail == self._cached_head: self._cached_head = self._head.load() if next_tail == self._cached_head: return False self._buffer[tail] = value self._tail.store(next_tail) return True def dequeue(self) -> Optional[T]: """Consumer removes item. Returns None if empty.""" head = self._head.load() if head == self._cached_tail: self._cached_tail = self._tail.load() if head == self._cached_tail: return None value = self._buffer[head] self._buffer[head] = None self._head.store((head + 1) & self._mask) return value def is_empty(self) -> bool: return self._head.load() == self._tail.load()
Question 4 - Count-Min Sketch
Difficulty: 8 / 10
Approximate lines of code: 70 LoC
Tags: probabilistic, data-structures
Description
A Count-Min Sketch is a probabilistic data structure for estimating the frequency of items in a data stream using sublinear space. Unlike a hash map that stores exact counts, a CMS uses a 2D array of counters with multiple hash functions. Each item hashes to one position per row, and all those positions get incremented. To query, you take the minimum across all rows - this gives an estimate that may overcount (due to hash collisions) but never undercounts.
The structure uses O(w * d) space where w is width and d is depth (number of hash functions). With proper sizing, you get: with probability >= 1 - delta, the estimate is at most true_count + epsilon * N, where N is total items seen. The formulas are w = ceil(e/epsilon) and d = ceil(ln(1/delta)).
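As a quick worked example of those formulas (the numbers are chosen purely for illustration):

import math

epsilon, delta = 0.001, 0.01             # allow 0.1% of N overcount, 99% confidence
width = math.ceil(math.e / epsilon)      # ceil(2718.28...) = 2719 counters per row
depth = math.ceil(math.log(1 / delta))   # ceil(4.605...)   = 5 rows
# Space is width * depth counters (2719 * 5 here), independent of how many
# distinct items the stream contains.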
Part A: Basic Structure
Problem: Part A
Implement a Count-Min Sketch with increment(item, count=1) and query(item) methods.
cms = CountMinSketch(width=100, depth=5)

# Increment counts
cms.increment("apple", 10)
cms.increment("banana", 5)
cms.increment("apple", 3)

# Query returns minimum across all rows
cms.query("apple")   # >= 13 (exact or overestimate)
cms.query("banana")  # >= 5
cms.query("cherry")  # >= 0 (never seen, but may have collisions)
Internal state (width=5, depth=3 for illustration):
After adding "apple" 13 times and "banana" 5 times:
Row 0: [0, 13, 5, 0, 0] <- "apple" hashes to col 1, "banana" to col 2
Row 1: [5, 0, 13, 0, 0] <- different hash: apple->2, banana->0
Row 2: [0, 13, 0, 5, 0] <- apple->1, banana->3
query("apple") = min(13, 13, 13) = 13
query("banana") = min(5, 5, 5) = 5
Part B: Optimal Parameters
Problem: Part B
Implement a method to calculate optimal width and depth given desired error bounds.
Part C: Merging and the No-Undercount Guarantee
Problem: Part C
Add support for merging two sketches (useful for distributed counting) and explain the no-undercount guarantee.
cms1 = CountMinSketch(width=100, depth=5)
cms2 = CountMinSketch(width=100, depth=5)

cms1.add("apple", 1000)
cms2.add("apple", 500)

cms1.merge(cms2)     # Element-wise max of all counters
cms1.query("apple")  # >= 1000 (from cms1's counters)
Why no undercount? Every position the true item hashes to was incremented by its full count, and counters never decrease, so each of those counters is at least the true count. Since query() takes the minimum of those counters, the estimate can never fall below the true count; in the best case at least one position has no collisions and the estimate is exact.
Interview comments
Edge cases to probe:
What happens when querying an item never added? (Returns 0 or collision noise)
Why use min() instead of max() or average?
How would you handle deletions? (You cannot - counts only go up)
What’s the space complexity vs. a hash map?
Common mistakes:
Using max() instead of min() for query (defeats the purpose)
Using a single hash function (no way to reduce collision error)
Using Python’s built-in hash() which is non-deterministic across runs
Forgetting that the sketch can only overestimate, never underestimate
Code solutions
Solutions Overview
Solution 1 is a classic implementation using MD5 hashing with seeded rows. Solution 2 uses a universal hash family h(x) = ((a*x + b) mod p) mod w for better theoretical collision bounds. Solution 3 implements the conservative update variant that reduces overestimation by only incrementing counters up to min + count. These vary in their hash function choice and update strategy. Core techniques: multiple hash functions, 2D counter array, min-query for estimates.
Solution 1: Classic MD5 hashing
Classic implementation using MD5 hashing with seed per row. Tracks total count for error bound calculation. Simple and readable.
"""Count-Min Sketch - Solution 1: Classic ImplementationA probabilistic data structure for frequency estimation.Uses multiple hash functions to provide approximate counts withguaranteed no underestimation."""from dataclasses import dataclass, fieldfrom typing import Listimport hashlib@dataclassclass CountMinSketch: """ Count-Min Sketch for approximate frequency counting. Args: width: Number of counters per row (controls accuracy) depth: Number of hash functions/rows (controls confidence) """ width: int depth: int table: List[List[int]] = field(init=False) total_count: int = field(default=0, init=False) def __post_init__(self) -> None: self.table = [[0] * self.width for _ in range(self.depth)] def _hash(self, item: str, seed: int) -> int: """Generate a hash for the item with a given seed.""" data = f"{seed}:{item}".encode("utf-8") digest = hashlib.md5(data).hexdigest() return int(digest, 16) % self.width def increment(self, item: str, count: int = 1) -> None: """Add an item to the sketch with optional count.""" for row in range(self.depth): col = self._hash(item, row) self.table[row][col] += count self.total_count += count def query(self, item: str) -> int: """ Estimate the count of an item. Returns the minimum across all hash positions. May overestimate but never underestimates. """ min_count = float("inf") for row in range(self.depth): col = self._hash(item, row) min_count = min(min_count, self.table[row][col]) return int(min_count) def error_bound(self, epsilon: float = None) -> float: """ Expected error is at most (epsilon * total_count). epsilon = e / width, where e is Euler's number. """ import math eps = math.e / self.width return eps * self.total_count
Solution 2: Universal hash family
Uses universal hash family h(x) = ((a*x + b) mod p) mod w for better theoretical guarantees. Provides detailed error analysis with epsilon, delta, and confidence calculations.
"""Count-Min Sketch - Solution 2: Using Universal HashingThis version uses a family of universal hash functions (ax + b mod p mod w)for better theoretical guarantees on collision probability."""from dataclasses import dataclass, fieldfrom typing import List, Hashableimport random@dataclassclass UniversalHashFamily: """Universal hash function family: h(x) = ((a*x + b) mod p) mod width.""" width: int depth: int prime: int = field(default=2**61 - 1, init=False) # Large Mersenne prime a_values: List[int] = field(init=False) b_values: List[int] = field(init=False) def __post_init__(self) -> None: random.seed(42) # Deterministic for reproducibility self.a_values = [random.randint(1, self.prime - 1) for _ in range(self.depth)] self.b_values = [random.randint(0, self.prime - 1) for _ in range(self.depth)] def hash(self, item: Hashable, row: int) -> int: x = hash(item) return ((self.a_values[row] * x + self.b_values[row]) % self.prime) % self.width@dataclassclass CountMinSketch: """ Count-Min Sketch with universal hashing. Error bounds: - With probability >= 1 - delta, estimate <= true_count + epsilon * N - epsilon = e / width - delta = e^(-depth) """ width: int depth: int hash_family: UniversalHashFamily = field(init=False) table: List[List[int]] = field(init=False) _total: int = field(default=0, init=False) def __post_init__(self) -> None: self.table = [[0] * self.width for _ in range(self.depth)] self.hash_family = UniversalHashFamily(self.width, self.depth) def add(self, item: Hashable, count: int = 1) -> None: """Increment the count for an item.""" for row in range(self.depth): col = self.hash_family.hash(item, row) self.table[row][col] += count self._total += count def estimate(self, item: Hashable) -> int: """Return the estimated count (minimum across all rows).""" return min( self.table[row][self.hash_family.hash(item, row)] for row in range(self.depth) ) @property def total_count(self) -> int: return self._total def theoretical_error(self) -> dict: """Return theoretical error analysis.""" import math epsilon = math.e / self.width delta = math.exp(-self.depth) return { "epsilon": epsilon, "delta": delta, "max_overestimate": epsilon * self._total, "confidence": 1 - delta, }
Solution 3: Conservative update variant
Conservative update variant that reduces overestimation. Only increments counters up to min + count rather than blindly adding to all. Includes merge and optimal parameter calculation.
"""Count-Min Sketch - Solution 3: Conservative Update VariantConservative update reduces overestimation by only incrementingcounters up to the current minimum + count, rather than always adding."""from dataclasses import dataclass, fieldfrom typing import List, Tupleimport hashlib@dataclassclass ConservativeCountMinSketch: """ Count-Min Sketch with conservative update optimization. Conservative update: instead of blindly incrementing all counters, only increment up to max(current_estimate, new_count). This reduces overestimation while maintaining the no-undercount guarantee. """ width: int depth: int seed: int = 0 table: List[List[int]] = field(init=False) _n: int = field(default=0, init=False) def __post_init__(self) -> None: self.table = [[0] * self.width for _ in range(self.depth)] def _hash(self, item: str, row: int) -> int: """Generate hash for item at given row.""" data = f"{self.seed + row}:{item}".encode("utf-8") digest = hashlib.sha256(data).hexdigest() return int(digest[:16], 16) % self.width def _get_positions(self, item: str) -> List[Tuple[int, int]]: """Get (row, col) positions for an item.""" return [(row, self._hash(item, row)) for row in range(self.depth)] def add(self, item: str, count: int = 1) -> None: """Add item with conservative update.""" positions = self._get_positions(item) # Get current minimum estimate current_counts = [self.table[r][c] for r, c in positions] min_count = min(current_counts) # Conservative update: set all positions to max of their current value # and (min_count + count) new_target = min_count + count for r, c in positions: self.table[r][c] = max(self.table[r][c], new_target) self._n += count def query(self, item: str) -> int: """Estimate count for item.""" positions = self._get_positions(item) return min(self.table[r][c] for r, c in positions) def merge(self, other: "ConservativeCountMinSketch") -> None: """Merge another sketch into this one (element-wise max).""" assert self.width == other.width and self.depth == other.depth for r in range(self.depth): for c in range(self.width): self.table[r][c] = max(self.table[r][c], other.table[r][c]) self._n += other._n @staticmethod def optimal_params(epsilon: float, delta: float) -> Tuple[int, int]: """Calculate optimal width and depth for desired error bounds.""" import math width = int(math.ceil(math.e / epsilon)) depth = int(math.ceil(math.log(1 / delta))) return width, depth
Question 5 - Double-Entry Ledger
Difficulty: 4 / 10
Approximate lines of code: 90 LoC
Tags: storage
Description
Double-entry bookkeeping is the foundation of modern accounting. Every transaction affects at least two accounts: one is debited, one is credited, and the total debits must equal total credits. This invariant makes errors detectable - if your books don’t balance, something is wrong. Assets and expenses have “debit normal” balances (debits increase them), while liabilities, equity, and revenue have “credit normal” balances (credits increase them).
The core data structures are: accounts (with names and types), transactions (with descriptions and multiple entries), and a ledger that validates and stores everything. A trial balance report sums all account balances to verify the books still balance. Use Decimal for money - floating point arithmetic will eventually produce rounding errors that break the balance invariant.
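A quick illustration of the float problem (this is standard Python behaviour, not specific to the solutions below):

from decimal import Decimal

print(0.1 + 0.2 == 0.3)                                      # False - binary float rounding
print(Decimal("0.10") + Decimal("0.20") == Decimal("0.30"))  # True - exact decimal arithmetic
# Over thousands of postings that tiny float error accumulates, and the trial
# balance stops netting to zero even though every transaction "looked" balanced.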
Part A: Accounts and Transactions
Problem: Part A
Implement account creation and transaction posting. Each transaction has multiple entries, and debits must equal credits before the transaction can be recorded.
Part B: Trial Balance
Problem: Part B
Generate a trial balance report that lists every account with its debit and credit balances. Total debits must equal total credits.
ledger.post_entry("Cash", "Revenue", 500.0, "Cash sale")ledger.post_entry("Accounts Receivable", "Revenue", 300.0, "Credit sale")ledger.post_entry("Rent Expense", "Cash", 200.0, "Paid rent")balances, is_balanced = ledger.trial_balance()# balances:# Cash: DR 500, CR 200 -> Net DR 300# Revenue: DR 0, CR 800 -> Net CR 800# Accounts Receivable: DR 300, CR 0 -> Net DR 300# Rent Expense: DR 200, CR 0 -> Net DR 200## Total DR: 300 + 300 + 200 = 800# Total CR: 800# is_balanced = True
Part C: Multi-Leg Transactions
Problem: Part C
Support transactions with multiple debits and credits (not just pairs). Example: selling inventory involves four entries - debit Cash (receive payment), credit Revenue (recognize sale), debit COGS (expense the cost), credit Inventory (reduce stock).
gl = GeneralLedger()
for acct in ["Cash", "Inventory", "Revenue", "COGS"]:
    gl.open_account(acct)

sale = Transaction("Sold inventory")
sale.add_debit("Cash", Decimal("150.00"))       # Receive $150
sale.add_credit("Revenue", Decimal("150.00"))   # Recognize revenue
sale.add_debit("COGS", Decimal("80.00"))        # Cost of goods sold
sale.add_credit("Inventory", Decimal("80.00"))  # Reduce inventory
gl.commit(sale)

# Total debits:  150 + 80 = 230
# Total credits: 150 + 80 = 230
# Transaction is balanced and committed
Interview comments
Edge cases to probe:
What happens if you try to post to a non-existent account?
How do you handle negative amounts in entries?
Can the same account appear multiple times in one transaction?
How do you display credit-normal accounts (negative debit balance)?
Common mistakes:
Using float instead of Decimal for monetary amounts
Confusing debit/credit signs (debits are not always positive)
Not validating that accounts exist before posting
Forgetting that credit-normal accounts show as “negative” debit balances
Code solutions
Solution 1 uses a simple dictionary-based approach where balances are calculated on-demand by iterating through all transactions (O(n) per query). Solution 2 takes an account-centric design where each account maintains running debit/credit totals for O(1) balance queries. Solution 3 supports multi-leg transactions with arbitrary numbers of debits and credits, using Decimal for precise money handling. These vary in their approach to balance tracking and transaction complexity: Solution 1 prioritizes simplicity, Solution 2 optimizes for read performance, and Solution 3 handles real-world accounting scenarios. Core techniques: dictionary storage, running balance tracking, Decimal arithmetic with ROUND_HALF_UP, transaction validation.
Solution 1: Simple Dictionary-Based Approach
Simple dictionary-based approach. Stores accounts as a dict mapping name to type, transactions as a list. Calculates balances by iterating through all transactions - O(n) per balance query. Clean and correct but not optimized for frequent queries.
"""Double-Entry Ledger - Solution 1: Simple Dictionary-Based ApproachUses dictionaries to track accounts and a list for transactions."""from dataclasses import dataclass, fieldfrom typing import Dict, List, Tuplefrom enum import Enumclass AccountType(Enum): ASSET = "asset" LIABILITY = "liability" EQUITY = "equity" REVENUE = "revenue" EXPENSE = "expense"@dataclassclass Entry: account: str amount: float is_debit: bool@dataclassclass Transaction: description: str entries: List[Entry] = field(default_factory=list) def is_balanced(self) -> bool: debits = sum(e.amount for e in self.entries if e.is_debit) credits = sum(e.amount for e in self.entries if not e.is_debit) return abs(debits - credits) < 0.001class Ledger: def __init__(self) -> None: self.accounts: Dict[str, AccountType] = {} self.transactions: List[Transaction] = [] def create_account(self, name: str, account_type: AccountType) -> None: self.accounts[name] = account_type def record_transaction(self, transaction: Transaction) -> None: if not transaction.is_balanced(): raise ValueError("Transaction must balance: debits must equal credits") for entry in transaction.entries: if entry.account not in self.accounts: raise ValueError(f"Account '{entry.account}' does not exist") self.transactions.append(transaction) def get_balance(self, account_name: str) -> float: if account_name not in self.accounts: raise ValueError(f"Account '{account_name}' does not exist") balance = 0.0 for txn in self.transactions: for entry in txn.entries: if entry.account == account_name: if entry.is_debit: balance += entry.amount else: balance -= entry.amount return balance def trial_balance(self) -> Tuple[Dict[str, float], float, float]: balances: Dict[str, float] = {} total_debits, total_credits = 0.0, 0.0 for account in self.accounts: bal = self.get_balance(account) balances[account] = bal if bal >= 0: total_debits += bal else: total_credits += abs(bal) return balances, total_debits, total_credits
Solution 2: Account-Centric with Balance Tracking
Account-centric design with running balance tracking. Each Account object maintains separate debit and credit totals, enabling O(1) balance queries. Transactions are simpler (just debit account, credit account, amount) since each entry affects exactly two accounts.
Solution 3: Multi-Leg Transactions with Validation
Multi-leg transactions with Decimal precision. Uses an enum for debit/credit sides and immutable Leg dataclass. Supports complex transactions with arbitrary numbers of debits and credits. Includes proper money rounding with ROUND_HALF_UP.
"""Double-Entry Ledger - Solution 3: Multi-Leg Transactions with ValidationSupports complex transactions with multiple debits and credits."""from dataclasses import dataclass, fieldfrom typing import Dict, List, Optionalfrom decimal import Decimal, ROUND_HALF_UPfrom enum import Enum, autoclass Side(Enum): DEBIT = auto() CREDIT = auto()@dataclass(frozen=True)class Leg: account: str amount: Decimal side: Side@dataclassclass Transaction: memo: str legs: List[Leg] = field(default_factory=list) def add_debit(self, account: str, amount: Decimal) -> "Transaction": self.legs.append(Leg(account, amount, Side.DEBIT)) return self def add_credit(self, account: str, amount: Decimal) -> "Transaction": self.legs.append(Leg(account, amount, Side.CREDIT)) return self def total_debits(self) -> Decimal: return sum((leg.amount for leg in self.legs if leg.side == Side.DEBIT), Decimal(0)) def total_credits(self) -> Decimal: return sum((leg.amount for leg in self.legs if leg.side == Side.CREDIT), Decimal(0)) def is_balanced(self) -> bool: return self.total_debits() == self.total_credits()class GeneralLedger: def __init__(self) -> None: self._accounts: Dict[str, List[Leg]] = {} self._transactions: List[Transaction] = [] def open_account(self, name: str) -> None: if name in self._accounts: raise ValueError(f"Account '{name}' already exists") self._accounts[name] = [] def commit(self, txn: Transaction) -> None: if not txn.is_balanced(): diff = txn.total_debits() - txn.total_credits() raise ValueError(f"Unbalanced transaction: off by {diff}") for leg in txn.legs: if leg.account not in self._accounts: raise KeyError(f"Unknown account: {leg.account}") for leg in txn.legs: self._accounts[leg.account].append(leg) self._transactions.append(txn) def balance(self, account: str) -> Decimal: if account not in self._accounts: raise KeyError(f"Unknown account: {account}") total = Decimal(0) for leg in self._accounts[account]: if leg.side == Side.DEBIT: total += leg.amount else: total -= leg.amount return total.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) def trial_balance_report(self) -> str: lines = ["TRIAL BALANCE", "=" * 50] total_dr, total_cr = Decimal(0), Decimal(0) for acct in sorted(self._accounts.keys()): bal = self.balance(acct) dr = bal if bal > 0 else Decimal(0) cr = abs(bal) if bal < 0 else Decimal(0) total_dr += dr total_cr += cr lines.append(f"{acct:<25} DR:{dr:>10.2f} CR:{cr:>10.2f}") lines.append("=" * 50) lines.append(f"{'TOTALS':<25} DR:{total_dr:>10.2f} CR:{total_cr:>10.2f}") lines.append(f"Balanced: {total_dr == total_cr}") return "\n".join(lines)
Question 6 - Task DAG Executor
Difficulty: 9 / 10
Approximate lines of code: 70 LoC
Tags: scheduling, concurrency
Description
Concurrency Cheat Sheet
This problem involves concurrency. Here is a cheat sheet of relevant Python syntax. You may not need all of these; the list is neither exhaustive nor a hint about what the problem requires.
# threading.Thread
t = threading.Thread(target=fn, args=(arg1,))
t.start()
t.join()

# threading.Lock
lock = threading.Lock()
lock.acquire()
lock.release()
with lock:
    ...  # critical section

# threading.Condition
cond = threading.Condition()
with cond:
    cond.wait()        # release lock, wait, reacquire
    cond.notify()      # wake one waiting thread
    cond.notify_all()  # wake all waiting threads

# threading.Semaphore
sem = threading.Semaphore(value=3)
sem.acquire()
sem.release()

# threading.Event
event = threading.Event()
event.set()     # set flag to True
event.clear()   # set flag to False
event.wait()    # block until flag is True
event.is_set()  # check flag

# concurrent.futures.ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=4) as executor:
    future = executor.submit(fn, arg1, arg2)
    result = future.result()
    futures = [executor.submit(fn, x) for x in items]
    for f in as_completed(futures):
        result = f.result()

# queue.Queue (thread-safe)
q = queue.Queue()
q.put(item)
item = q.get()  # blocks until available
q.task_done()
q.join()  # block until all tasks done
A Task DAG (Directed Acyclic Graph) Executor is the core of build systems like Make, Bazel, and Gradle. Tasks have dependencies: Task B depends on Task A means A must complete before B can start. The executor must find a valid execution order (topological sort) or detect if the dependencies form a cycle (impossible to execute). This is fundamentally a topological sorting problem.
Two main approaches exist: Kahn’s algorithm (BFS-based) counts incoming edges and processes nodes with zero in-degree; DFS-based post-order traversal visits dependencies recursively and appends nodes after all dependencies are visited. Both run in O(V + E) time. Internal state typically includes a task map (name → task), an adjacency list (task → list of dependents), and in-degree counts or visit state.
Part A: Basic Execution Order
Problem: Part A
Given tasks with dependencies, return a valid execution order. If there’s a cycle, raise an error.
Part B: Parallel Batches
Problem: Part B
Instead of a flat list, return batches of tasks that can run in parallel: the first batch has no dependencies, the second batch depends only on tasks in the first batch, and so on.
tasks = [ Task("A", fn_a, dependencies=[]), Task("B", fn_b, dependencies=[]), Task("C", fn_c, dependencies=["A"]), Task("D", fn_d, dependencies=["A", "B"]), Task("E", fn_e, dependencies=["C", "D"]),]get_parallel_batches(tasks)# Returns: [["A", "B"], ["C", "D"], ["E"]]# Batch 1: A and B have no deps, run in parallel# Batch 2: C and D only depend on batch 1# Batch 3: E depends on batch 2
Part C: Actual Parallel Execution
Problem: Part C
Execute tasks concurrently using a thread pool. Each task waits for its dependencies to complete before running. Handle failures gracefully - if a task fails, tasks that depend on it should be skipped or marked as failed.
def slow_compile():
    time.sleep(1.0)
    return "compiled"

def slow_test():
    time.sleep(1.0)
    return "tested"

tasks = [
    Task("compile", slow_compile, dependencies=[]),
    Task("test", slow_test, dependencies=["compile"]),
    Task("lint", slow_lint, dependencies=[]),  # independent
]

# Sequential would take 3+ seconds
# Parallel: compile and lint run together (1s), then test (1s) = 2s total
start = time.time()
results = execute_tasks(tasks, max_workers=4)
elapsed = time.time() - start  # ~2 seconds, not 3
Interview comments
Edge cases to probe:
What if a dependency doesn’t exist in the task list?
What about self-loops (task depends on itself)?
What if the graph has disconnected components (independent task groups)?
How do you handle a task that fails mid-execution?
Common mistakes:
Only starting DFS from one node (misses disconnected components)
Using simple visited set for cycle detection (only detects revisits, not back-edges - need three-color WHITE/GRAY/BLACK)
Not validating that dependencies exist before execution
Forgetting to handle the “no tasks” case
In parallel version: not waiting for all dependencies before starting a task
In parallel version: submitting tasks before all futures are registered (race condition)
Code solutions
Solution 1 uses Kahn’s algorithm (BFS-based), counting in-degrees and processing tasks with zero in-degree in batches. Solution 2 uses DFS with three-color marking (UNVISITED/VISITING/VISITED) for cycle detection and handles disconnected components. Solution 3 adds concurrent execution with ThreadPoolExecutor, where each task waits on its dependency futures before executing. The key difference is sequential vs parallel execution and BFS vs DFS traversal. Core techniques: topological sort, Kahn’s algorithm, DFS cycle detection, thread pools.
Solution 1: Kahn’s algorithm (BFS-based)
Kahn’s algorithm (BFS-based). Builds in-degree map and processes tasks with zero in-degree. Each iteration represents a parallel batch. Detects cycles when not all tasks are executed.
"""Solution 1: Kahn's Algorithm (BFS-based topological sort)Uses in-degree counting and a queue to process tasks level by level."""from dataclasses import dataclass, fieldfrom typing import Callable, Anyfrom collections import deque@dataclassclass Task: name: str fn: Callable[[], Any] dependencies: list[str] = field(default_factory=list)class CircularDependencyError(Exception): passdef execute_tasks(tasks: list[Task]) -> dict[str, Any]: """Execute tasks in topological order using Kahn's algorithm.""" task_map = {t.name: t for t in tasks} in_degree = {t.name: 0 for t in tasks} dependents: dict[str, list[str]] = {t.name: [] for t in tasks} # Build graph and calculate in-degrees for task in tasks: for dep in task.dependencies: if dep not in task_map: raise ValueError(f"Unknown dependency: {dep}") dependents[dep].append(task.name) in_degree[task.name] += 1 # Start with tasks that have no dependencies queue = deque([name for name, degree in in_degree.items() if degree == 0]) results: dict[str, Any] = {} executed_count = 0 while queue: # All tasks in current queue can run in parallel parallel_batch = list(queue) queue.clear() for name in parallel_batch: task = task_map[name] results[name] = task.fn() executed_count += 1 # Reduce in-degree for dependent tasks for dependent in dependents[name]: in_degree[dependent] -= 1 if in_degree[dependent] == 0: queue.append(dependent) if executed_count != len(tasks): raise CircularDependencyError("Circular dependency detected") return results
Solution 2: DFS with three-color marking
DFS-based with three-color marking (UNVISITED/VISITING/VISITED). Detects cycles by finding back-edges (revisiting a VISITING node). Handles disconnected components by iterating all unvisited nodes.
"""Solution 2: DFS-based topological sort with cycle detectionUses recursion with three-color marking (white/gray/black) for cycle detection."""from dataclasses import dataclass, fieldfrom typing import Callable, Anyfrom enum import Enumclass State(Enum): UNVISITED = 0 VISITING = 1 # Currently in recursion stack (gray) VISITED = 2 # Fully processed (black)@dataclassclass Task: name: str fn: Callable[[], Any] dependencies: list[str] = field(default_factory=list)class CircularDependencyError(Exception): passdef execute_tasks(tasks: list[Task]) -> dict[str, Any]: """Execute tasks using DFS topological sort with cycle detection.""" task_map = {t.name: t for t in tasks} state: dict[str, State] = {t.name: State.UNVISITED for t in tasks} execution_order: list[str] = [] def dfs(name: str) -> None: if state[name] == State.VISITING: raise CircularDependencyError(f"Cycle detected at task: {name}") if state[name] == State.VISITED: return state[name] = State.VISITING task = task_map[name] for dep in task.dependencies: if dep not in task_map: raise ValueError(f"Unknown dependency: {dep}") dfs(dep) state[name] = State.VISITED execution_order.append(name) # Visit all tasks (handles disconnected components) for task in tasks: if state[task.name] == State.UNVISITED: dfs(task.name) # Execute in topological order and collect results results: dict[str, Any] = {} for name in execution_order: results[name] = task_map[name].fn() return results
Solution 3: Concurrent execution with ThreadPoolExecutor
Concurrent execution with ThreadPoolExecutor. Submits all tasks upfront, each task waits on its dependency futures before executing. Uses a gate lock to ensure all futures are registered before any task checks dependencies (avoids race condition).
"""Solution 3: Concurrent execution with ThreadPoolExecutorActually parallelizes independent tasks using threading."""from dataclasses import dataclass, fieldfrom typing import Callable, Anyfrom concurrent.futures import ThreadPoolExecutor, Futurefrom threading import Lockimport time@dataclassclass Task: name: str fn: Callable[[], Any] dependencies: list[str] = field(default_factory=list)class CircularDependencyError(Exception): passdef detect_cycle(tasks: list[Task]) -> bool: """Detect cycles using DFS with coloring.""" task_map = {t.name: t for t in tasks} WHITE, GRAY, BLACK = 0, 1, 2 color = {t.name: WHITE for t in tasks} def dfs(name: str) -> bool: color[name] = GRAY for dep in task_map[name].dependencies: if dep not in task_map: raise ValueError(f"Unknown dependency: {dep}") if color[dep] == GRAY: return True # Back edge = cycle if color[dep] == WHITE and dfs(dep): return True color[name] = BLACK return False return any(color[t.name] == WHITE and dfs(t.name) for t in tasks)def execute_tasks(tasks: list[Task], max_workers: int = 4) -> dict[str, Any]: """Execute tasks concurrently, respecting dependencies.""" if detect_cycle(tasks): raise CircularDependencyError("Circular dependency detected") task_map = {t.name: t for t in tasks} results: dict[str, Any] = {} futures: dict[str, Future] = {} lock = Lock() ready = Lock() # Gate to ensure all futures are registered before execution def run_task(name: str) -> Any: with ready: # Wait until all futures are registered pass task = task_map[name] # Wait for all dependencies to complete for dep in task.dependencies: futures[dep].result() result = task.fn() with lock: results[name] = result return result with ThreadPoolExecutor(max_workers=max_workers) as executor: with ready: # Hold the gate while submitting for task in tasks: futures[task.name] = executor.submit(run_task, task.name) # Gate released, tasks can now execute for future in futures.values(): future.result() return results
Question 7 - Quadtree
Difficulty: 7 / 10
Approximate lines of code: 90 LoC
Tags: data-structures, game/simulation
Description
A quadtree is a tree data structure where each internal node has exactly four children, used to partition a 2D space by recursively subdividing it into four quadrants. It is commonly used in collision detection (games), spatial indexing (GIS), and image compression. Each node represents a rectangular region; when a node contains too many objects, it splits into four children (NW, NE, SW, SE). The key insight is that spatial queries can skip entire subtrees whose bounding boxes do not intersect the query region.
Internal state example after inserting points (10,10), (20,20), (80,80):
Root [boundary: (0,0)-(100,100), capacity: 4]
points: [(10,10), (20,20), (80,80)] # Not yet subdivided
After inserting 2 more points, capacity exceeded:
Root [divided=True]
NW [boundary: (0,50)-(50,100)] -> points: []
NE [boundary: (50,50)-(100,100)] -> points: [(80,80)]
SW [boundary: (0,0)-(50,50)] -> points: [(10,10), (20,20)]
SE [boundary: (50,0)-(100,50)] -> points: []
Part A: Point Storage and Region Queries
Problem: Part A
Implement a quadtree that stores 2D points and supports rectangle queries.
qt = QuadTree(boundary=Rectangle(50, 50, 50, 50), capacity=4)
qt.insert(Point(10, 10))
qt.insert(Point(20, 20))
qt.insert(Point(80, 80))

# Query: find all points in region [(5,5) to (25,25)]
results = qt.query(Rectangle(15, 15, 10, 10))
# Returns: [Point(10,10), Point(20,20)]

# Internal state: if capacity not exceeded, points stored in leaf
# If exceeded, node subdivides and redistributes points to children
Key operations:
insert(point): Add point to appropriate leaf; subdivide if capacity exceeded
query(region): Return all points within the given rectangle
subdivide(): Split node into four children (NW, NE, SW, SE)
Part B: Rectangle Storage and Collision Detection
Problem: Part B
Extend the quadtree to store rectangles (axis-aligned bounding boxes) instead of points. Find all stored rectangles that intersect a query rectangle.
Key difference from Part A: A rectangle may span multiple quadrants, so it must be inserted into all children it intersects (or kept at the parent level).
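For flavour, a possible usage sketch in the style of Part A's example, borrowing the AABB/Entity naming that Solution 2 further down happens to use (treat the exact API as an assumption):
qt = QuadTree(bounds=AABB(0, 0, 100, 100), max_objects=4)
qt.insert(Entity(id=1, bounds=AABB(40, 40, 60, 60)))  # straddles the centre; after a split it would live in all four children
qt.insert(Entity(id=2, bounds=AABB(10, 10, 20, 20)))  # fits entirely inside the SW quadrant
hits = qt.query_collisions(AABB(35, 35, 45, 45))
# Only entity 1 intersects the query box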
Part C: Deletion and Rebalancing
Problem: Part C
Implement deletion of objects and optional merging of sparse nodes back into their parent.
qt = QuadTree(Bounds(0, 0, 100), capacity=2)
for i in range(20):
    qt.insert(Point(i*4+5, i*4+5, id=i))
assert qt.depth() >= 3  # Tree has grown deep

# Rebuild with larger capacity to reduce depth
rebuilt = qt.rebuild(new_capacity=8)
assert rebuilt.depth() < qt.depth()

# Check if tree needs rebalancing
if qt.needs_rebalance(max_depth=10):
    qt = qt.rebuild()
Consider: when should you merge children back into their parent? A common heuristic is to merge once the total number of objects across all four children falls below the node capacity.
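A hedged sketch of that merge pass, written against the point-based tree from Solution 3 below (the try_merge name is an assumption):
def try_merge(self) -> None:
    """Collapse leaf children back into this node when their combined contents fit in one node."""
    if self.children is None:
        return
    for child in self.children:
        child.try_merge()  # merge bottom-up first
    if any(child.children is not None for child in self.children):
        return  # some child is still subdivided; can't collapse yet
    total = sum(len(child.points) for child in self.children)
    if total <= self.capacity:
        for child in self.children:
            self.points.extend(child.points)
        self.children = None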
Interview comments
Edge cases to probe:
What happens when a point falls exactly on a quadrant boundary?
How do you handle objects that span multiple quadrants?
What if all points cluster in one corner (degenerate case)?
Does query handle regions that extend outside the tree boundary?
Common mistakes:
Off-by-one errors in boundary checks (using < vs <=)
Infinite recursion when a point sits exactly on a subdivision boundary
Forgetting to redistribute existing points after subdivision
Not pruning search when query region does not intersect node boundary
For rectangles: only inserting into one child instead of all overlapping children
Code solutions
Solution 1 implements a basic point-based quadtree using center+half-width rectangle representation with recursive subdivision. Solution 2 extends to rectangle storage using min/max corner AABB representation for collision detection, where objects can span multiple quadrants. Solution 3 adds rebuild and rebalancing capabilities to maintain tree efficiency after many operations. The key difference is what they store (points vs rectangles) and whether they support tree maintenance operations.
Solution 1: Basic point-based quadtree
Basic point-based quadtree using center+half-width rectangle representation. Stores points in leaf nodes until capacity is exceeded, then subdivides and redistributes. Query recursively checks intersection before descending.
"""Quadtree implementation - Basic point-based quadtree with region queries."""from dataclasses import dataclass, fieldfrom typing import Optional@dataclassclass Point: x: float y: float data: any = None@dataclassclass Rectangle: x: float # center x y: float # center y w: float # half-width h: float # half-height def contains(self, point: Point) -> bool: return (self.x - self.w <= point.x <= self.x + self.w and self.y - self.h <= point.y <= self.y + self.h) def intersects(self, other: "Rectangle") -> bool: return not (other.x - other.w > self.x + self.w or other.x + other.w < self.x - self.w or other.y - other.h > self.y + self.h or other.y + other.h < self.y - self.h)@dataclassclass QuadTree: boundary: Rectangle capacity: int = 4 points: list[Point] = field(default_factory=list) divided: bool = False nw: Optional["QuadTree"] = None ne: Optional["QuadTree"] = None sw: Optional["QuadTree"] = None se: Optional["QuadTree"] = None def subdivide(self) -> None: x, y, w, h = self.boundary.x, self.boundary.y, self.boundary.w / 2, self.boundary.h / 2 self.nw = QuadTree(Rectangle(x - w, y + h, w, h), self.capacity) self.ne = QuadTree(Rectangle(x + w, y + h, w, h), self.capacity) self.sw = QuadTree(Rectangle(x - w, y - h, w, h), self.capacity) self.se = QuadTree(Rectangle(x + w, y - h, w, h), self.capacity) self.divided = True def insert(self, point: Point) -> bool: if not self.boundary.contains(point): return False if len(self.points) < self.capacity and not self.divided: self.points.append(point) return True if not self.divided: self.subdivide() for p in self.points: self._insert_into_children(p) self.points = [] return self._insert_into_children(point) def _insert_into_children(self, point: Point) -> bool: return (self.nw.insert(point) or self.ne.insert(point) or self.sw.insert(point) or self.se.insert(point)) def query(self, region: Rectangle, found: list[Point] = None) -> list[Point]: if found is None: found = [] if not self.boundary.intersects(region): return found for p in self.points: if region.contains(p): found.append(p) if self.divided: self.nw.query(region, found) self.ne.query(region, found) self.sw.query(region, found) self.se.query(region, found) return found
Solution 2: Rectangle-based quadtree with collision detection
Rectangle-based quadtree using min/max corner AABB representation. Designed for collision detection with entities that have bounding boxes. Objects spanning multiple quadrants are inserted into all relevant children.
"""Quadtree implementation - Rectangle storage with collision detection."""from dataclasses import dataclass, fieldfrom typing import Optional@dataclass(frozen=True)class AABB: """Axis-aligned bounding box.""" min_x: float min_y: float max_x: float max_y: float @property def center(self) -> tuple[float, float]: return ((self.min_x + self.max_x) / 2, (self.min_y + self.max_y) / 2) def contains_point(self, x: float, y: float) -> bool: return self.min_x <= x <= self.max_x and self.min_y <= y <= self.max_y def intersects(self, other: "AABB") -> bool: return not (self.max_x < other.min_x or self.min_x > other.max_x or self.max_y < other.min_y or self.min_y > other.max_y)@dataclassclass Entity: id: int bounds: AABB@dataclassclass QuadTree: bounds: AABB max_objects: int = 4 max_depth: int = 8 depth: int = 0 objects: list[Entity] = field(default_factory=list) children: Optional[list["QuadTree"]] = None def _split(self) -> None: cx, cy = self.bounds.center self.children = [ QuadTree(AABB(self.bounds.min_x, cy, cx, self.bounds.max_y), self.max_objects, self.max_depth, self.depth + 1), QuadTree(AABB(cx, cy, self.bounds.max_x, self.bounds.max_y), self.max_objects, self.max_depth, self.depth + 1), QuadTree(AABB(self.bounds.min_x, self.bounds.min_y, cx, cy), self.max_objects, self.max_depth, self.depth + 1), QuadTree(AABB(cx, self.bounds.min_y, self.bounds.max_x, cy), self.max_objects, self.max_depth, self.depth + 1), ] def insert(self, entity: Entity) -> bool: if not self.bounds.intersects(entity.bounds): return False if self.children is None and (len(self.objects) < self.max_objects or self.depth >= self.max_depth): self.objects.append(entity) return True if self.children is None: self._split() for obj in self.objects: for child in self.children: child.insert(obj) self.objects = [] for child in self.children: child.insert(entity) return True def query_collisions(self, target: AABB) -> list[Entity]: results = [] if not self.bounds.intersects(target): return results for obj in self.objects: if obj.bounds.intersects(target): results.append(obj) if self.children: for child in self.children: results.extend(child.query_collisions(target)) return results def clear(self) -> None: self.objects.clear() self.children = None
Solution 3: Point-based quadtree with rebuild/rebalance
Point-based quadtree with rebuild and rebalance support. Tracks total point count, supports iterating all points, and can rebuild with different capacity to reduce tree depth after many insertions/deletions.
"""Quadtree implementation - With rebuild/rebalance support."""from dataclasses import dataclass, fieldfrom typing import Optional, Iterator@dataclassclass Point: x: float y: float id: int = 0@dataclassclass Bounds: x: float y: float size: float # Square region for simplicity def contains(self, p: Point) -> bool: return self.x <= p.x < self.x + self.size and self.y <= p.y < self.y + self.size def intersects_rect(self, rx: float, ry: float, rw: float, rh: float) -> bool: return not (rx > self.x + self.size or rx + rw < self.x or ry > self.y + self.size or ry + rh < self.y)@dataclassclass QuadTree: bounds: Bounds capacity: int = 4 points: list[Point] = field(default_factory=list) children: Optional[list["QuadTree"]] = None _total_points: int = 0 def _subdivide(self) -> None: x, y, half = self.bounds.x, self.bounds.y, self.bounds.size / 2 self.children = [ QuadTree(Bounds(x, y + half, half), self.capacity), QuadTree(Bounds(x + half, y + half, half), self.capacity), QuadTree(Bounds(x, y, half), self.capacity), QuadTree(Bounds(x + half, y, half), self.capacity), ] def insert(self, point: Point) -> bool: if not self.bounds.contains(point): return False self._total_points += 1 if self.children is None and len(self.points) < self.capacity: self.points.append(point) return True if self.children is None: self._subdivide() for p in self.points: for child in self.children: if child.insert(p): break self.points = [] for child in self.children: if child.insert(point): return True return False def query_region(self, rx: float, ry: float, rw: float, rh: float) -> list[Point]: results = [] if not self.bounds.intersects_rect(rx, ry, rw, rh): return results for p in self.points: if rx <= p.x <= rx + rw and ry <= p.y <= ry + rh: results.append(p) if self.children: for child in self.children: results.extend(child.query_region(rx, ry, rw, rh)) return results def all_points(self) -> Iterator[Point]: yield from self.points if self.children: for child in self.children: yield from child.all_points() def rebuild(self, new_capacity: Optional[int] = None) -> "QuadTree": """Rebuild tree, optionally with new capacity.""" cap = new_capacity if new_capacity else self.capacity new_tree = QuadTree(self.bounds, cap) for p in self.all_points(): new_tree.insert(p) return new_tree def depth(self) -> int: if self.children is None: return 1 return 1 + max(child.depth() for child in self.children) def needs_rebalance(self, max_depth: int = 10) -> bool: return self.depth() > max_depth
Question 8 - Two-Phase Commit
Difficulty: 10 / 10
Approximate lines of code: 90 LoC
Tags: distributed-systems, concurrency
Description
Concurrency Cheat Sheet
This problem involves concurrency. Here is a cheat sheet of relevant Python syntax. You may not need all of these; the list is neither exhaustive nor a hint at which primitives the problem requires.
# threading.Thread
t = threading.Thread(target=func, args=(arg1, arg2))
t.start()
t.join()

# threading.Lock
lock = threading.Lock()
lock.acquire()
lock.release()
with lock:
    ...  # critical section

# threading.Condition
cond = threading.Condition()
with cond:
    cond.wait()        # release lock, wait for notify
    cond.notify()      # wake one waiter
    cond.notify_all()  # wake all waiters

# threading.Semaphore
sem = threading.Semaphore(value=3)
sem.acquire()
sem.release()

# threading.Event
event = threading.Event()
event.set()     # signal
event.clear()   # reset
event.wait()    # block until set
event.is_set()  # check without blocking

# queue.Queue (thread-safe)
from queue import Queue
q = Queue()
q.put(item)
item = q.get()          # blocks
item = q.get_nowait()   # raises Empty if none

# Atomic operations concept
# Python's GIL makes single bytecode ops atomic
# For true atomics, use: threading.Lock or ctypes/cffi
# CAS pattern: if current == expected: current = new
Two-Phase Commit (2PC) is a protocol for coordinating distributed transactions across multiple nodes (databases, services, etc.). When a transaction spans multiple systems, you need all of them to either commit or abort together - partial commits corrupt data. The coordinator orchestrates the protocol, and participants vote on whether they can commit.
The two phases are: (1) Prepare: coordinator asks all participants “can you commit?”, participants lock resources and vote YES or NO; (2) Commit/Abort: if all voted YES, coordinator tells everyone to COMMIT, otherwise tells everyone to ABORT. The key insight: once a participant votes YES, it must be able to commit later (resources stay locked until the coordinator’s decision arrives).
Part A: Happy Path
Problem: Part A
Implement Coordinator and Participant classes. The coordinator sends PREPARE to all participants, collects votes, then sends COMMIT if all voted YES or ABORT if any voted NO.
p2 = Participant("db2", fail_on_prepare=True) # Will vote NOcoordinator = Coordinator([p1, p2, p3])result = coordinator.execute_transaction("txn-002")# Phase 1: p2 votes ABORT# Phase 2: coordinator sends ABORT to allresult # False# All participants in ABORTED state
Part B: Participant Failures
Problem: Part B
Handle failures during the protocol. Key rules:
Failure during PREPARE: Simply abort the transaction (participant couldn’t prepare)
Failure during COMMIT (after all voted YES): the coordinator must retry indefinitely - the decision has already been made and MUST be honored (see the retry sketch after the example below)
# Participant fails during commit phase
p1 = Participant("db1")
p2 = Participant("db2", fail_on_commit=True)  # Prepare succeeds, commit fails
coordinator = Coordinator([p1, p2])
result = coordinator.run_2pc("txn-003")
# Phase 1: Both vote COMMIT
# Phase 2: p2.commit() fails!
# Coordinator MUST retry p2.commit() until it succeeds
# Cannot abort - p1 already committed!
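A rough sketch of that retry loop; the helper name and backoff constants are assumptions:
import time


def commit_with_retry(participant, txn_id: str, base_delay: float = 0.1) -> None:
    """Once the coordinator has logged COMMIT, keep retrying this participant until it succeeds."""
    delay = base_delay
    while True:
        try:
            if participant.commit(txn_id):
                return
        except Exception:
            pass  # participant crashed or unreachable; the decision still stands
        time.sleep(delay)
        delay = min(delay * 2, 5.0)  # exponential backoff, capped at 5s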
Timeout handling:
p1 = Participant("db1", simulate_timeout=True)# Phase 1: p1 times out during prepare# Coordinator aborts entire transaction (safe - no one committed yet)
Part C: Write-Ahead Logging
Problem: Part C
Add durability. Before responding to any message, participants and coordinator must log their state. This enables crash recovery.
# Participant logging:
def prepare(self, txn_id):
    # Do work to prepare...
    self.log.write(txn_id, "PREPARED")  # LOG BEFORE responding
    return Vote.COMMIT

def commit(self, txn_id):
    self.log.write(txn_id, "COMMITTED")  # LOG BEFORE applying
    # Apply changes...

# Coordinator logging:
def run_2pc(self, txn_id):
    votes = [p.prepare(txn_id) for p in participants]
    decision = COMMIT if all_yes else ABORT
    self.log.write(txn_id, decision)  # LOG DECISION before sending
    # Now send commit/abort to all participants
Recovery after crash:
# Participant recovery:
state = log.get_state(txn_id)
if state == PREPARED:
    ...  # Ask coordinator for decision (might block!)
elif state == COMMITTED:
    ...  # Re-apply if needed (idempotent)
elif state == None:
    ...  # Never prepared, safe to abort
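To make the recovery branches concrete, a small sketch; the state enum and callback parameters are assumptions:
from enum import Enum
from typing import Callable, Optional


class RecoveredState(Enum):
    PREPARED = "prepared"
    COMMITTED = "committed"


def recover_participant(
    state: Optional[RecoveredState],
    ask_coordinator: Callable[[], str],
    apply_changes: Callable[[], None],
) -> str:
    """Decide what a restarted participant should do based on its last logged state."""
    if state is RecoveredState.PREPARED:
        # Voted YES but never saw the outcome: must ask (and possibly block on) the coordinator
        decision = ask_coordinator()
        if decision == "commit":
            apply_changes()
        return decision
    if state is RecoveredState.COMMITTED:
        apply_changes()  # re-applying must be idempotent
        return "commit"
    return "abort"  # never prepared, so presumed abort is safe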
Interview comments
Edge cases to probe:
What happens if coordinator crashes after all participants voted COMMIT but before sending commit? (Participants are BLOCKED - this is 2PC’s fundamental limitation)
What if a participant crashes after voting YES but before receiving commit? (Must recover and ask coordinator for decision)
Can you ever abort after someone voted YES? (Only if coordinator also failed and you use a timeout-based presumed abort)
Common mistakes:
Not retrying forever on commit failure (once decided, must commit)
Forgetting write-ahead logging (crashes corrupt data without it)
Not understanding the blocking problem (2PC can block indefinitely)
Confusing 2PC with consensus (2PC is for transactions, not leader election)
Sending commit before logging the decision (crash loses decision)
Code solutions
Solution 1 is a basic synchronous implementation with failure simulation flags, showing the core protocol flow without logging or retries. Solution 2 adds timeout handling and write-ahead logging for crash recovery, with participants logging state before responding. Solution 3 is an async implementation using asyncio for concurrent participant coordination, with retry logic for the commit phase. These differ in their handling of failures (none vs timeouts vs retries) and execution model (sync vs async). Core techniques: state machines, write-ahead logging, timeout handling, retry with backoff.
Solution 1: Basic synchronous implementation
Basic synchronous implementation with failure simulation flags. Clean state machine for participants. No logging or retry logic - just the core protocol flow.
"""Two-Phase Commit Protocol - Basic ImplementationClean synchronous implementation with failure simulation."""from dataclasses import dataclassfrom enum import Enumfrom typing import Callableimport randomclass Vote(Enum): COMMIT = "commit" ABORT = "abort"class ParticipantState(Enum): INITIAL = "initial" PREPARED = "prepared" COMMITTED = "committed" ABORTED = "aborted"@dataclassclass Participant: name: str state: ParticipantState = ParticipantState.INITIAL fail_on_prepare: bool = False fail_on_commit: bool = False def prepare(self, transaction_id: str) -> Vote: if self.fail_on_prepare: return Vote.ABORT self.state = ParticipantState.PREPARED return Vote.COMMIT def commit(self, transaction_id: str) -> bool: if self.fail_on_commit: return False self.state = ParticipantState.COMMITTED return True def abort(self, transaction_id: str) -> bool: self.state = ParticipantState.ABORTED return True@dataclassclass Coordinator: participants: list[Participant] transaction_id: str = "" def execute_transaction(self, txn_id: str) -> bool: self.transaction_id = txn_id # Phase 1: Prepare votes = [] for p in self.participants: vote = p.prepare(txn_id) votes.append(vote) # Decision all_commit = all(v == Vote.COMMIT for v in votes) # Phase 2: Commit or Abort if all_commit: for p in self.participants: p.commit(txn_id) return True else: for p in self.participants: p.abort(txn_id) return False
Solution 2: With timeout and recovery
Adds timeout handling and write-ahead logging for crash recovery. Participants log state transitions before responding. Coordinator logs decision before sending commit/abort. Includes recovery method to restore state from log.
"""Two-Phase Commit Protocol - With Timeout and RecoveryHandles coordinator/participant failures with timeout simulation."""from dataclasses import dataclass, fieldfrom enum import Enumfrom typing import Optionalimport timeclass TxnState(Enum): INIT = "init" PREPARED = "prepared" COMMITTED = "committed" ABORTED = "aborted" TIMEOUT = "timeout"@dataclassclass TransactionLog: """Write-ahead log for crash recovery.""" entries: list[tuple[str, str, TxnState]] = field(default_factory=list) def write(self, txn_id: str, participant: str, state: TxnState) -> None: self.entries.append((txn_id, participant, state)) def get_state(self, txn_id: str, participant: str) -> Optional[TxnState]: for t, p, s in reversed(self.entries): if t == txn_id and p == participant: return s return None@dataclassclass Participant: name: str log: TransactionLog = field(default_factory=TransactionLog) state: TxnState = TxnState.INIT simulate_timeout: bool = False def prepare(self, txn_id: str, timeout_ms: int = 1000) -> TxnState: if self.simulate_timeout: return TxnState.TIMEOUT self.state = TxnState.PREPARED self.log.write(txn_id, self.name, TxnState.PREPARED) return TxnState.PREPARED def commit(self, txn_id: str) -> TxnState: self.state = TxnState.COMMITTED self.log.write(txn_id, self.name, TxnState.COMMITTED) return TxnState.COMMITTED def abort(self, txn_id: str) -> TxnState: self.state = TxnState.ABORTED self.log.write(txn_id, self.name, TxnState.ABORTED) return TxnState.ABORTED def recover(self, txn_id: str) -> Optional[TxnState]: return self.log.get_state(txn_id, self.name)@dataclassclass Coordinator: participants: list[Participant] log: TransactionLog = field(default_factory=TransactionLog) timeout_ms: int = 1000 def run_2pc(self, txn_id: str) -> bool: # Phase 1: Prepare with timeout handling prepared = [] for p in self.participants: result = p.prepare(txn_id, self.timeout_ms) if result == TxnState.TIMEOUT: self._abort_all(txn_id) return False prepared.append(result == TxnState.PREPARED) # Log decision before Phase 2 decision = all(prepared) self.log.write(txn_id, "coordinator", TxnState.COMMITTED if decision else TxnState.ABORTED) # Phase 2: Execute decision if decision: for p in self.participants: p.commit(txn_id) return True else: self._abort_all(txn_id) return False def _abort_all(self, txn_id: str) -> None: for p in self.participants: p.abort(txn_id)
Solution 3: Async with retries
Async implementation with concurrent participant coordination and retry logic. Uses asyncio for parallel prepare/commit phases. Implements retry with backoff for commit phase failures (once decided, must succeed).
"""Two-Phase Commit Protocol - Async with RetriesProduction-style implementation with async coordination and retry logic."""from dataclasses import dataclass, fieldfrom enum import Enumfrom typing import Callable, Optionalimport asyncioclass Phase(Enum): PREPARE = "prepare" COMMIT = "commit" ABORT = "abort"class Result(Enum): OK = "ok" FAIL = "fail" TIMEOUT = "timeout"@dataclassclass Participant: id: str fail_phase: Optional[Phase] = None delay_ms: int = 0 async def prepare(self) -> Result: await asyncio.sleep(self.delay_ms / 1000) if self.fail_phase == Phase.PREPARE: return Result.FAIL return Result.OK async def commit(self) -> Result: await asyncio.sleep(self.delay_ms / 1000) if self.fail_phase == Phase.COMMIT: return Result.FAIL return Result.OK async def abort(self) -> Result: return Result.OK@dataclassclass TwoPhaseCommit: participants: list[Participant] timeout_sec: float = 1.0 max_retries: int = 3 async def execute(self, txn_id: str) -> bool: # Phase 1: Prepare all participants concurrently prepare_results = await self._run_phase(Phase.PREPARE) if all(r == Result.OK for r in prepare_results): # Phase 2a: Commit all commit_results = await self._run_phase_with_retry(Phase.COMMIT) return all(r == Result.OK for r in commit_results) else: # Phase 2b: Abort all await self._run_phase(Phase.ABORT) return False async def _run_phase(self, phase: Phase) -> list[Result]: async def run_one(p: Participant) -> Result: try: method = getattr(p, phase.value) return await asyncio.wait_for(method(), self.timeout_sec) except asyncio.TimeoutError: return Result.TIMEOUT tasks = [run_one(p) for p in self.participants] return await asyncio.gather(*tasks) async def _run_phase_with_retry(self, phase: Phase) -> list[Result]: """Retry phase for participants that failed (commit must eventually succeed).""" results = [Result.FAIL] * len(self.participants) for attempt in range(self.max_retries): for i, p in enumerate(self.participants): if results[i] != Result.OK: method = getattr(p, phase.value) try: results[i] = await asyncio.wait_for(method(), self.timeout_sec) except asyncio.TimeoutError: results[i] = Result.TIMEOUT if all(r == Result.OK for r in results): break return results
Question 9 - Connection Pool
Difficulty: 8 / 10
Approximate lines of code: 100 LoC
Tags: concurrency, storage
Description
Concurrency Cheat Sheet
This problem involves concurrency. Here is a cheat sheet of relevant Python syntax. You may not need all of these; this list is not exhaustive or suggestive for the problem.
A connection pool manages reusable database connections to avoid the overhead of creating new connections for every request. The pool maintains a fixed maximum size, hands out connections on checkout, and accepts them back on checkin. Key challenges include thread-safety (multiple threads requesting connections concurrently), validation (detecting broken connections), and timeouts (what happens when the pool is exhausted).
Internally, you need: a collection of available connections, a set tracking checked-out connections, a lock for thread-safety, and a condition variable for blocking when the pool is exhausted. Connections should be validated before being returned to ensure they’re still usable.
Part A: Basic Pool
Problem: Part A
Implement checkout() that returns a connection (creating one if needed, up to max size) and checkin(conn) that returns a connection to the pool. If all connections are in use and max size is reached, block until one becomes available.
pool = ConnectionPool(max_size=2)
c1 = pool.checkout()  # Creates new connection (id=0)
c2 = pool.checkout()  # Creates new connection (id=1)
# Internal state:
# available: []
# in_use: {Connection(0), Connection(1)}
# total: 2 (at max)

c3 = pool.checkout(timeout=0.1)  # Returns None (pool exhausted, timeout)

pool.checkin(c1)
# Internal state:
# available: [Connection(0)]
# in_use: {Connection(1)}

c4 = pool.checkout()  # Returns Connection(0) from available pool
Part B: Validation
Problem: Part B
Before returning a connection on checkout, verify it’s still valid. If invalid, discard it and try the next available connection (or create a new one). Also validate on checkin - don’t return broken connections to the pool.
pool = ConnectionPool(max_size=2, validator=lambda c: c.is_valid)
c1 = pool.checkout()  # id=0
c1.is_valid = False   # Simulate broken connection
pool.checkin(c1)      # Connection discarded, not returned to pool
# Internal state:
# available: []
# in_use: {}
# total connections: 0 (broken one was discarded)

c2 = pool.checkout()  # Creates new connection (id=1)
Part C: Timeout
Problem: Part C
Implement blocking with timeout when the pool is exhausted. Use a condition variable to efficiently wait for a connection to become available rather than busy-polling.
pool = ConnectionPool(max_size=1, timeout=5.0)
c1 = pool.checkout()  # Gets connection immediately

# In another thread:
c2 = pool.checkout(timeout=2.0)
# Blocks for up to 2 seconds waiting for c1 to be returned
# Returns None if timeout expires

# Back in first thread:
pool.checkin(c1)  # Wakes up waiting thread via condition.notify()
Interview comments
Edge cases to probe:
What if someone calls checkin twice with the same connection? (Guard against double-checkin)
What if checkin is called with a connection that was never checked out? (Ignore or error)
What about connection staleness (idle too long)? (Add max_idle_time check)
What happens if validation itself throws an exception? (Treat as invalid)
Common mistakes:
Race condition: checking pool size and creating connection not atomic
Holding lock during connection creation (blocks other threads unnecessarily)
Not using condition variable (busy-waiting wastes CPU)
Double-checkin causing pool to grow beyond max_size
Resource leak if validation fails after checkout but before user gets connection
Code solutions
Solutions Overview
Solution 1 uses Lock and Condition with a list/set for available and in-use tracking. Solution 2 switches to queue.Queue with context manager support and staleness detection. Solution 3 uses Semaphore for cleaner concurrency limiting with deque for FIFO reuse and health checks. The key difference is the synchronization primitive: Condition for flexible blocking, Queue for built-in thread-safety, Semaphore for simple capacity limiting. Core techniques: locks, condition variables, semaphores, context managers.
Solution 1: Lock and Condition
Basic implementation using threading.Lock and Condition. Uses list for available connections and set for in-use tracking. Simple and correct with proper condition variable usage for blocking.
"""Connection Pool - Solution 1: Basic Threading with LockUses a simple list-based pool with threading.Lock for synchronization."""from dataclasses import dataclass, fieldfrom threading import Lock, Conditionfrom typing import Optional, Callableimport time@dataclassclass Connection: """Represents a database connection.""" id: int created_at: float = field(default_factory=time.time) is_valid: bool = True def __hash__(self) -> int: return hash(self.id) def __eq__(self, other: object) -> bool: return isinstance(other, Connection) and self.id == other.id def close(self) -> None: self.is_valid = False@dataclassclass ConnectionPool: """Thread-safe connection pool with timeout and validation.""" max_size: int timeout: float = 30.0 validator: Callable[[Connection], bool] = lambda c: c.is_valid _available: list[Connection] = field(default_factory=list, init=False) _in_use: set[Connection] = field(default_factory=set, init=False) _lock: Lock = field(default_factory=Lock, init=False) _condition: Condition = field(default_factory=Condition, init=False) _next_id: int = field(default=0, init=False) def _create_connection(self) -> Connection: conn = Connection(id=self._next_id) self._next_id += 1 return conn def checkout(self, timeout: Optional[float] = None) -> Optional[Connection]: """Get a connection from the pool. Returns None if timeout expires.""" wait_timeout = timeout if timeout is not None else self.timeout deadline = time.time() + wait_timeout with self._condition: while True: # Try to get a valid connection from available pool while self._available: conn = self._available.pop() if self.validator(conn): self._in_use.add(conn) return conn conn.close() # Discard invalid connection # Create new connection if under limit if len(self._in_use) < self.max_size: conn = self._create_connection() self._in_use.add(conn) return conn # Wait for a connection to be returned remaining = deadline - time.time() if remaining <= 0: return None self._condition.wait(timeout=remaining) def checkin(self, conn: Connection) -> None: """Return a connection to the pool.""" with self._condition: if conn in self._in_use: self._in_use.remove(conn) if self.validator(conn): self._available.append(conn) else: conn.close() self._condition.notify() def size(self) -> int: """Return total connections (available + in use).""" with self._lock: return len(self._available) + len(self._in_use)
Solution 2: Queue-Based with Context Manager
Queue-based implementation with context manager support. Adds staleness detection (max_idle_time) and provides a with pool.connection() interface for automatic checkout/checkin with guaranteed cleanup.
"""Connection Pool - Solution 2: Queue-Based with Context ManagerUses queue.Queue for thread-safety and provides context manager interface."""from dataclasses import dataclass, fieldfrom queue import Queue, Empty, Fullfrom threading import Lockfrom typing import Optional, Callable, Iteratorfrom contextlib import contextmanagerimport time@dataclassclass Connection: """Represents a pooled connection.""" id: int created_at: float = field(default_factory=time.time) last_used: float = field(default_factory=time.time) _closed: bool = field(default=False, init=False) def is_valid(self) -> bool: return not self._closed def close(self) -> None: self._closed = True def touch(self) -> None: self.last_used = time.time()@dataclassclass ConnectionPool: """Queue-based connection pool with context manager support.""" max_size: int timeout: float = 30.0 max_idle_time: float = 300.0 validator: Callable[[Connection], bool] = field(default=lambda c: c.is_valid()) _pool: Queue[Connection] = field(init=False) _in_use: int = field(default=0, init=False) _next_id: int = field(default=0, init=False) _lock: Lock = field(default_factory=Lock, init=False) def __post_init__(self) -> None: self._pool = Queue(maxsize=self.max_size) def _create_connection(self) -> Connection: conn = Connection(id=self._next_id) self._next_id += 1 return conn def _is_stale(self, conn: Connection) -> bool: return time.time() - conn.last_used > self.max_idle_time def checkout(self, timeout: Optional[float] = None) -> Optional[Connection]: """Acquire a connection from the pool.""" wait_time = timeout if timeout is not None else self.timeout deadline = time.time() + wait_time while True: # Try to get from pool first try: conn = self._pool.get_nowait() if self.validator(conn) and not self._is_stale(conn): conn.touch() with self._lock: self._in_use += 1 return conn conn.close() continue # Try again except Empty: pass # Pool empty - try to create new connection with self._lock: if self._in_use + self._pool.qsize() < self.max_size: self._in_use += 1 return self._create_connection() # At max capacity, wait for return remaining = deadline - time.time() if remaining <= 0: return None try: conn = self._pool.get(timeout=remaining) if self.validator(conn) and not self._is_stale(conn): conn.touch() with self._lock: self._in_use += 1 return conn conn.close() except Empty: return None def checkin(self, conn: Connection) -> None: """Return a connection to the pool.""" with self._lock: self._in_use -= 1 if not self.validator(conn): conn.close() return conn.touch() try: self._pool.put_nowait(conn) except Full: conn.close() @contextmanager def connection(self, timeout: Optional[float] = None) -> Iterator[Connection]: """Context manager for automatic checkout/checkin.""" conn = self.checkout(timeout=timeout) if conn is None: raise TimeoutError("Could not acquire connection from pool") try: yield conn finally: self.checkin(conn)
Solution 3: Semaphore-Based with Health Checks
Semaphore-based implementation with health checks. Uses semaphore to limit concurrent connections (cleaner than manual counting). Includes deque for FIFO connection reuse and explicit health check method on connections.
"""Connection Pool - Solution 3: Semaphore-Based with Health ChecksUses semaphores for limiting concurrent connections and periodic health checks."""from dataclasses import dataclass, fieldfrom threading import Semaphore, Lockfrom typing import Optional, Callablefrom collections import dequeimport time@dataclassclass Connection: """Connection with health check support.""" id: int created_at: float = field(default_factory=time.time) _healthy: bool = field(default=True, init=False) def ping(self) -> bool: """Simulate health check.""" return self._healthy def execute(self, query: str) -> str: """Simulate query execution.""" if not self._healthy: raise ConnectionError("Connection is unhealthy") return f"Result from conn-{self.id}: {query}" def close(self) -> None: self._healthy = False@dataclassclass ConnectionPool: """Semaphore-based pool with health checking.""" max_size: int timeout: float = 30.0 health_check: Callable[[Connection], bool] = field(default=lambda c: c.ping()) _semaphore: Semaphore = field(init=False) _connections: deque[Connection] = field(default_factory=deque, init=False) _checked_out: set[int] = field(default_factory=set, init=False) _lock: Lock = field(default_factory=Lock, init=False) _next_id: int = field(default=0, init=False) def __post_init__(self) -> None: self._semaphore = Semaphore(self.max_size) def _create_connection(self) -> Connection: conn = Connection(id=self._next_id) self._next_id += 1 return conn def checkout(self, timeout: Optional[float] = None) -> Optional[Connection]: """Acquire a healthy connection.""" wait_time = timeout if timeout is not None else self.timeout if not self._semaphore.acquire(timeout=wait_time): return None with self._lock: # Try to find a healthy connection in the pool attempts = len(self._connections) while attempts > 0: if not self._connections: break conn = self._connections.popleft() attempts -= 1 if self.health_check(conn): self._checked_out.add(conn.id) return conn conn.close() # Create new connection conn = self._create_connection() self._checked_out.add(conn.id) return conn def checkin(self, conn: Connection) -> None: """Return connection to pool after health check.""" with self._lock: if conn.id not in self._checked_out: return # Already returned or never checked out self._checked_out.remove(conn.id) if self.health_check(conn): self._connections.append(conn) else: conn.close() self._semaphore.release() def stats(self) -> dict[str, int]: """Return pool statistics.""" with self._lock: return { "available": len(self._connections), "in_use": len(self._checked_out), "total_created": self._next_id, }
Question 10 - Trie Autocomplete
Difficulty: 3 / 10
Approximate lines of code: 80 LoC
Tags: data-structures
Description
A trie (prefix tree) is a tree data structure where each node represents a character, and paths from root to nodes spell out prefixes. Each node contains a dictionary mapping characters to child nodes, plus metadata like is_end (marks complete words) and optional ranking data (frequency, recency). Tries enable O(k) prefix lookup where k is the prefix length, making them ideal for autocomplete systems.
The core operations are: (1) insert by walking/creating nodes for each character, (2) search by walking nodes and checking is_end, and (3) autocomplete by finding the prefix node then collecting all words in that subtree via DFS.
Part A: Basic Trie Operations
Problem: Part A
Implement insert(word) and search(word). Insert creates nodes as needed and marks the final node as a word ending. Search traverses nodes and returns True only if the path exists AND ends at a complete word.
Implement autocomplete(prefix) that returns all words starting with the given prefix. Navigate to the prefix node, then DFS to collect all complete words in that subtree.
Track how many times each word is inserted and rank autocomplete results by frequency (most frequent first). Repeated insertions of the same word increment its count.
trie = Trie()trie.insert("app")trie.insert("app")trie.insert("app") # frequency=3trie.insert("apple") # frequency=1trie.insert("application") # frequency=1suggestions = trie.autocomplete("ap", limit=2)# Returns ["app", "apple"] - "app" first due to higher frequency
Interview comments
Edge cases to probe:
Empty string prefix (should return all words)?
Case sensitivity handling?
What if limit exceeds available matches?
Single-character words?
Common mistakes:
Forgetting is_end flag (treating every node as a word)
Inefficient string concatenation in DFS (use list + join)
Not handling empty prefix case
Returning nodes instead of reconstructed strings
Code solutions
Solution 1 implements a basic trie with frequency-based ranking, collecting all words via DFS then sorting. Solution 2 uses recency-based ranking with a monotonic timestamp counter, also supporting soft delete. Solution 3 combines frequency and recency into a weighted hybrid score and uses a min-heap to efficiently maintain top-k results during collection. These vary in their ranking strategy: pure frequency vs pure recency vs weighted hybrid. Core techniques: tries, DFS traversal, min-heaps for top-k.
Solution 1: Basic Trie with Frequency Ranking
A basic trie with frequency-based ranking. Nodes track is_end and frequency. Autocomplete collects all words via DFS then sorts by frequency descending.
"""Trie Autocomplete - Solution 1: Basic Trie with Frequency RankingA straightforward trie implementation using a dictionary for children.Ranks suggestions by insertion frequency."""from dataclasses import dataclass, fieldfrom typing import Optional@dataclassclass TrieNode: children: dict[str, "TrieNode"] = field(default_factory=dict) is_end: bool = False frequency: int = 0class Trie: def __init__(self) -> None: self.root = TrieNode() def insert(self, word: str) -> None: """Insert a word into the trie. Increments frequency if word exists.""" node = self.root for char in word: if char not in node.children: node.children[char] = TrieNode() node = node.children[char] node.is_end = True node.frequency += 1 def search(self, word: str) -> bool: """Return True if word exists in the trie.""" node = self._find_node(word) return node is not None and node.is_end def starts_with(self, prefix: str) -> bool: """Return True if any word starts with the given prefix.""" return self._find_node(prefix) is not None def autocomplete(self, prefix: str, limit: int = 10) -> list[str]: """Return words with given prefix, ranked by frequency (descending).""" node = self._find_node(prefix) if node is None: return [] results: list[tuple[str, int]] = [] self._collect_words(node, prefix, results) results.sort(key=lambda x: -x[1]) return [word for word, _ in results[:limit]] def _find_node(self, prefix: str) -> Optional[TrieNode]: """Traverse to the node representing the prefix.""" node = self.root for char in prefix: if char not in node.children: return None node = node.children[char] return node def _collect_words( self, node: TrieNode, prefix: str, results: list[tuple[str, int]] ) -> None: """DFS to collect all words under a node.""" if node.is_end: results.append((prefix, node.frequency)) for char, child in node.children.items(): self._collect_words(child, prefix + char, results)
Solution 2: Trie with Recency Ranking
A recency-based trie where more recently inserted/accessed words rank higher. Uses a monotonic timestamp counter. Also implements soft delete by clearing is_end.
"""Trie Autocomplete - Solution 2: Trie with Recency RankingUses a timestamp-based approach to rank suggestions by recency.More recent insertions appear first in autocomplete results."""from dataclasses import dataclass, fieldfrom typing import Optional@dataclassclass TrieNode: children: dict[str, "TrieNode"] = field(default_factory=dict) is_end: bool = False last_used: int = 0class RecencyTrie: def __init__(self) -> None: self.root = TrieNode() self._timestamp = 0 def insert(self, word: str) -> None: """Insert a word, updating its recency timestamp.""" self._timestamp += 1 node = self.root for char in word: if char not in node.children: node.children[char] = TrieNode() node = node.children[char] node.is_end = True node.last_used = self._timestamp def search(self, word: str) -> bool: """Return True if word exists. Updates recency on access.""" node = self._find_node(word) if node is not None and node.is_end: self._timestamp += 1 node.last_used = self._timestamp return True return False def delete(self, word: str) -> bool: """Soft delete a word from the trie.""" node = self._find_node(word) if node is not None and node.is_end: node.is_end = False node.last_used = 0 return True return False def autocomplete(self, prefix: str, limit: int = 10) -> list[str]: """Return words with given prefix, ranked by recency (most recent first).""" node = self._find_node(prefix) if node is None: return [] results: list[tuple[str, int]] = [] self._collect_words(node, prefix, results) results.sort(key=lambda x: -x[1]) return [word for word, _ in results[:limit]] def _find_node(self, prefix: str) -> Optional[TrieNode]: """Traverse to the node representing the prefix.""" node = self.root for char in prefix: if char not in node.children: return None node = node.children[char] return node def _collect_words( self, node: TrieNode, prefix: str, results: list[tuple[str, int]] ) -> None: """DFS to collect all words under a node.""" if node.is_end: results.append((prefix, node.last_used)) for char, child in node.children.items(): self._collect_words(child, prefix + char, results)
Solution 3: Hybrid Ranking with Top-K Heap
A hybrid approach combining frequency and recency with a weighted score. Uses a min-heap to efficiently maintain top-k results during DFS collection, avoiding a full sort.
"""Trie Autocomplete - Solution 3: Hybrid Ranking with Top-K HeapCombines frequency and recency using a weighted score.Uses a heap to efficiently get top-k suggestions."""import heapqfrom dataclasses import dataclass, fieldfrom typing import Optional@dataclassclass TrieNode: children: dict[str, "TrieNode"] = field(default_factory=dict) is_end: bool = False frequency: int = 0 last_used: int = 0class HybridTrie: def __init__(self, recency_weight: float = 0.3) -> None: self.root = TrieNode() self._timestamp = 0 self._recency_weight = recency_weight def insert(self, word: str) -> None: """Insert a word, updating frequency and recency.""" self._timestamp += 1 node = self.root for char in word: if char not in node.children: node.children[char] = TrieNode() node = node.children[char] node.is_end = True node.frequency += 1 node.last_used = self._timestamp def search(self, word: str) -> bool: """Return True if word exists.""" node = self._find_node(word) return node is not None and node.is_end def get_frequency(self, word: str) -> int: """Return the frequency of a word, or 0 if not found.""" node = self._find_node(word) return node.frequency if node and node.is_end else 0 def autocomplete(self, prefix: str, limit: int = 10) -> list[str]: """Return top-k words ranked by hybrid score (frequency + recency).""" node = self._find_node(prefix) if node is None: return [] # Use min-heap to track top-k efficiently heap: list[tuple[float, str]] = [] self._collect_with_heap(node, prefix, heap, limit) # Extract results in descending order results = [] while heap: _, word = heapq.heappop(heap) results.append(word) return results[::-1] def _score(self, node: TrieNode) -> float: """Compute hybrid score: weighted combination of frequency and recency.""" freq_score = node.frequency recency_score = node.last_used / max(1, self._timestamp) return freq_score + self._recency_weight * recency_score def _find_node(self, prefix: str) -> Optional[TrieNode]: """Traverse to the node representing the prefix.""" node = self.root for char in prefix: if char not in node.children: return None node = node.children[char] return node def _collect_with_heap( self, node: TrieNode, prefix: str, heap: list[tuple[float, str]], limit: int ) -> None: """DFS collecting words, maintaining a min-heap of top-k.""" if node.is_end: score = self._score(node) if len(heap) < limit: heapq.heappush(heap, (score, prefix)) elif score > heap[0][0]: heapq.heapreplace(heap, (score, prefix)) for char, child in node.children.items(): self._collect_with_heap(child, prefix + char, heap, limit)
Question 11 - Chess Validator
Difficulty: 5 / 10
Approximate lines of code: 90 LoC
Tags: game/simulation
Description
A chess move validator determines whether a proposed move is legal according to the rules of chess. This involves checking piece-specific movement patterns, path blocking for sliding pieces (rook, bishop, queen), and special moves like castling and en passant. The validator does not need to track game state across moves - it validates a single move given the current board position.
The board is represented as an 8x8 grid (row 0-7, col 0-7) with pieces identified by color and type. Key data structures: a dictionary mapping positions to pieces, and optionally castling rights and en passant target square. The core logic involves computing delta (dr, dc) between start and end positions and validating against piece-specific rules.
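Before the parts below, here is a minimal sketch of the delta idea on bare (row, col) tuples; the helper names are illustrative and not part of the reference solutions.

# Delta-based shape checks (illustrative helpers, not the solution API)
def is_knight_shape(start: tuple[int, int], end: tuple[int, int]) -> bool:
    dr, dc = abs(end[0] - start[0]), abs(end[1] - start[1])
    return (dr, dc) in {(1, 2), (2, 1)}

def is_diagonal(start: tuple[int, int], end: tuple[int, int]) -> bool:
    dr, dc = abs(end[0] - start[0]), abs(end[1] - start[1])
    return dr == dc and dr > 0

print(is_knight_shape((3, 3), (1, 2)))  # True (moves 2 rows and 1 column)
print(is_diagonal((0, 3), (3, 0)))      # True (bishop-style move)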
Part A: Basic Piece Movement
Problem: Part A
Implement movement validation for all piece types. Knights move in L-shapes, bishops move diagonally, rooks move in straight lines, queens combine both, kings move one square any direction, and pawns move forward (with initial two-square option).
board = Board(squares={
    (0, 4): Piece(WHITE, 'K'),  # White king on e1
    (1, 4): Piece(WHITE, 'P'),  # White pawn on e2
    (3, 3): Piece(BLACK, 'N'),  # Black knight on d4
})

# Knight: L-shaped moves (2+1 or 1+2)
validate_move(board, (3, 3), (1, 2), BLACK)  # d4 to c2: True
validate_move(board, (3, 3), (2, 2), BLACK)  # d4 to c3: False (not L-shape)

# Pawn: forward one, or two from starting row
validate_move(board, (1, 4), (2, 4), WHITE)  # e2 to e3: True
validate_move(board, (1, 4), (3, 4), WHITE)  # e2 to e4: True (initial double)
validate_move(board, (1, 4), (2, 5), WHITE)  # e2 to f3: False (diagonal without capture)
Part B: Path Blocking and Captures
Problem: Part B
Sliding pieces (rook, bishop, queen) cannot jump over other pieces. Validate that the path between start and end is clear. Also validate that captures target opponent pieces only.
board = Board(squares={
    (0, 0): Piece(WHITE, 'R'),  # White rook on a1
    (0, 3): Piece(WHITE, 'B'),  # White bishop on d1 (blocking)
    (0, 7): Piece(BLACK, 'R'),  # Black rook on h1
})

# Rook blocked by own piece
validate_move(board, (0, 0), (0, 5), WHITE)  # a1 to f1: False (blocked by bishop)
validate_move(board, (0, 0), (0, 2), WHITE)  # a1 to c1: True (path clear)

# Cannot capture own piece
validate_move(board, (0, 0), (0, 3), WHITE)  # a1 to d1: False (own bishop there)

# Can capture opponent piece
board2 = Board(squares={
    (0, 0): Piece(WHITE, 'R'),
    (0, 5): Piece(BLACK, 'P'),
})
validate_move(board2, (0, 0), (0, 5), WHITE)  # a1 to f1: True (captures black pawn)
Part C: Special Moves
Problem: Part C
Implement castling (king moves two squares toward rook, rook jumps over), en passant (pawn captures pawn that just moved two squares), and pawn promotion (reaching last rank).
# Castling: King moves 2 squares, path must be clear, rights must exist
board = Board(
    squares={(0, 4): Piece(WHITE, 'K'), (0, 7): Piece(WHITE, 'R')},
    castling_rights={WHITE: {"kingside": True, "queenside": True}}
)
validate_move(board, (0, 4), (0, 6), WHITE)  # e1 to g1 (kingside castle): True

# En passant: capture pawn that just double-moved
board = Board(
    squares={(4, 4): Piece(WHITE, 'P'), (4, 5): Piece(BLACK, 'P')},
    en_passant_target=(5, 5)  # Black pawn just moved f7-f5
)
validate_move(board, (4, 4), (5, 5), WHITE)  # e5 to f6 (en passant): True
Interview comments
Edge cases to probe:
Can a pawn capture forward? (No - only diagonal captures)
What about pawn double-move when blocked on first square?
Can you castle through check? Out of check? (No to both)
En passant is only valid immediately after the double-move
Common mistakes:
Wrong pawn direction (white moves +row, black moves -row, or vice versa)
Forgetting path blocking for sliding pieces
Allowing pawn diagonal move without capture
Off-by-one in path checking (excluding start but not end, or vice versa)
Castling through pieces or when king would pass through check
Not validating that en passant target is actually capturable
Code solutions
Solution 1 uses an object-oriented approach with Position, Piece, and Board dataclasses, dispatching validation based on piece type. Solution 2 takes a functional approach with precomputed move tables for knights/kings and ray-based validation for sliding pieces. Solution 3 provides a compact implementation with generator-based move enumeration and includes check detection. These vary in code organization and whether moves are validated individually or enumerated as a set. Core techniques: delta-based movement patterns, ray casting for sliding pieces, and bounds checking on board coordinates.
Solution 1: Object-oriented with piece classes
Object-oriented approach with Position, Piece, and Board dataclasses. Board handles path-clear checking with directional iteration. validate_move function dispatches based on piece type.
"""Chess Move Validator - Solution 1: Object-Oriented Approach with Piece Classes"""from dataclasses import dataclassfrom typing import Optionalfrom enum import Enumclass Color(Enum): WHITE = "white" BLACK = "black"@dataclassclass Position: row: int # 0-7 col: int # 0-7 def is_valid(self) -> bool: return 0 <= self.row < 8 and 0 <= self.col < 8@dataclassclass Piece: color: Color piece_type: str # 'K', 'Q', 'R', 'B', 'N', 'P'@dataclassclass Board: squares: dict[tuple[int, int], Optional[Piece]] en_passant_target: Optional[Position] = None castling_rights: dict[Color, dict[str, bool]] = None def __post_init__(self): if self.castling_rights is None: self.castling_rights = { Color.WHITE: {"kingside": True, "queenside": True}, Color.BLACK: {"kingside": True, "queenside": True} } def get(self, pos: Position) -> Optional[Piece]: return self.squares.get((pos.row, pos.col)) def is_path_clear(self, start: Position, end: Position) -> bool: dr = 0 if end.row == start.row else (1 if end.row > start.row else -1) dc = 0 if end.col == start.col else (1 if end.col > start.col else -1) r, c = start.row + dr, start.col + dc while (r, c) != (end.row, end.col): if self.squares.get((r, c)): return False r, c = r + dr, c + dc return Truedef validate_move(board: Board, start: Position, end: Position, color: Color) -> bool: if not start.is_valid() or not end.is_valid(): return False piece = board.get(start) if not piece or piece.color != color: return False target = board.get(end) if target and target.color == color: return False dr, dc = end.row - start.row, end.col - start.col adr, adc = abs(dr), abs(dc) if piece.piece_type == 'N': return (adr, adc) in [(1, 2), (2, 1)] if piece.piece_type == 'B': return adr == adc and adr > 0 and board.is_path_clear(start, end) if piece.piece_type == 'R': return (dr == 0 or dc == 0) and (adr + adc > 0) and board.is_path_clear(start, end) if piece.piece_type == 'Q': return ((adr == adc) or (dr == 0 or dc == 0)) and (adr + adc > 0) and board.is_path_clear(start, end) if piece.piece_type == 'K': if adr <= 1 and adc <= 1 and (adr + adc > 0): return True # Castling if dr == 0 and adc == 2 and board.castling_rights[color]["kingside" if dc > 0 else "queenside"]: return board.is_path_clear(start, end) return False if piece.piece_type == 'P': direction = 1 if color == Color.WHITE else -1 start_row = 1 if color == Color.WHITE else 6 if dc == 0 and not target: if dr == direction: return True if dr == 2 * direction and start.row == start_row and board.is_path_clear(start, end): return True if adc == 1 and dr == direction: if target: return True if board.en_passant_target and end.row == board.en_passant_target.row and end.col == board.en_passant_target.col: return True return False return False
Solution 2: Functional with move tables
Functional approach with precomputed move tables for knights and kings. Uses ray-based validation for sliding pieces. Compact piece representation as strings like “wK”, “bP”.
"""Chess Move Validator - Solution 2: Functional Approach with Move Tables"""from dataclasses import dataclass, fieldfrom typing import Callable@dataclass(frozen=True)class Pos: r: int c: int def valid(self) -> bool: return 0 <= self.r < 8 and 0 <= self.c < 8 def __add__(self, other: "Pos") -> "Pos": return Pos(self.r + other.r, self.c + other.c)@dataclassclass GameState: board: dict[Pos, str] = field(default_factory=dict) # "wP", "bK", etc. en_passant: Pos | None = None castle: set[str] = field(default_factory=lambda: {"wK", "wQ", "bK", "bQ"})def color(piece: str) -> str: return piece[0]def ptype(piece: str) -> str: return piece[1]KNIGHT_MOVES = [Pos(dr, dc) for dr in [-2,-1,1,2] for dc in [-2,-1,1,2] if abs(dr) + abs(dc) == 3]KING_MOVES = [Pos(dr, dc) for dr in [-1,0,1] for dc in [-1,0,1] if (dr, dc) != (0, 0)]DIAGONALS = [Pos(d, d) for d in [-1, 1]] + [Pos(d, -d) for d in [-1, 1]]STRAIGHTS = [Pos(0, 1), Pos(0, -1), Pos(1, 0), Pos(-1, 0)]def ray_clear(state: GameState, start: Pos, direction: Pos, dist: int) -> bool: pos = start for _ in range(dist - 1): pos = pos + direction if pos in state.board: return False return Truedef validate_sliding(state: GameState, start: Pos, end: Pos, dirs: list[Pos]) -> bool: dr, dc = end.r - start.r, end.c - start.c for d in dirs: if d.r * dc == d.c * dr and (d.r == 0 or dr // d.r > 0) and (d.c == 0 or dc // d.c > 0): dist = max(abs(dr), abs(dc)) if dist > 0 and ray_clear(state, start, d, dist): return True return Falsedef validate_pawn(state: GameState, start: Pos, end: Pos, c: str) -> bool: fwd = 1 if c == 'w' else -1 home = 1 if c == 'w' else 6 dr, dc = end.r - start.r, end.c - start.c target = state.board.get(end) if dc == 0 and not target: if dr == fwd: return True if dr == 2 * fwd and start.r == home and Pos(start.r + fwd, start.c) not in state.board: return True if abs(dc) == 1 and dr == fwd: if target and color(target) != c: return True if end == state.en_passant: return True return Falsedef validate_king(state: GameState, start: Pos, end: Pos, c: str) -> bool: dr, dc = abs(end.r - start.r), abs(end.c - start.c) if dr <= 1 and dc <= 1 and (dr + dc) > 0: return True if dr == 0 and dc == 2: # Castling side = "K" if end.c > start.c else "Q" if f"{c}{side}" in state.castle and ray_clear(state, start, Pos(0, 1 if side == "K" else -1), dc): return True return Falsedef validate_move(state: GameState, start: Pos, end: Pos) -> bool: if not start.valid() or not end.valid(): return False piece = state.board.get(start) if not piece: return False target = state.board.get(end) if target and color(target) == color(piece): return False c, pt = color(piece), ptype(piece) if pt == 'N': return end - start in [(m.r, m.c) for m in KNIGHT_MOVES] or any(start + m == end for m in KNIGHT_MOVES) if pt == 'B': return validate_sliding(state, start, end, DIAGONALS) if pt == 'R': return validate_sliding(state, start, end, STRAIGHTS) if pt == 'Q': return validate_sliding(state, start, end, DIAGONALS + STRAIGHTS) if pt == 'K': return validate_king(state, start, end, c) if pt == 'P': return validate_pawn(state, start, end, c) return False
Solution 3: Compact with check detection
Compact implementation with generator-based move enumeration. Includes check detection (is_valid returns whether a move is legal, find_king locates the king, in_check determines if a color is in check). Uses minimal code with tuple-based positions.
"""Chess Move Validator - Solution 3: Compact with Check/Checkmate Detection"""from dataclasses import dataclassfrom typing import Iterator@dataclass(frozen=True)class Sq: r: int; c: int def ok(self) -> bool: return 0 <= self.r < 8 and 0 <= self.c < 8@dataclassclass Chess: board: dict[tuple[int,int], str] # (r,c) -> "wK", "bP", etc. ep: tuple[int,int] | None = None # en passant target castle: set[str] | None = None # "wK","wQ","bK","bQ" def __post_init__(self): if self.castle is None: self.castle = {"wK", "wQ", "bK", "bQ"} def at(self, r: int, c: int) -> str | None: return self.board.get((r, c)) def col(self, p: str) -> str: return p[0] def typ(self, p: str) -> str: return p[1] def ray_ok(self, r1: int, c1: int, r2: int, c2: int) -> bool: dr = (r2 > r1) - (r2 < r1); dc = (c2 > c1) - (c2 < c1) r, c = r1 + dr, c1 + dc while (r, c) != (r2, c2): if self.at(r, c): return False r, c = r + dr, c + dc return True def piece_moves(self, r: int, c: int) -> Iterator[tuple[int,int]]: p = self.at(r, c) if not p: return col, typ = self.col(p), self.typ(p) opp = 'b' if col == 'w' else 'w' if typ == 'N': for dr, dc in [(-2,-1),(-2,1),(-1,-2),(-1,2),(1,-2),(1,2),(2,-1),(2,1)]: nr, nc = r+dr, c+dc if 0 <= nr < 8 and 0 <= nc < 8: t = self.at(nr, nc) if not t or self.col(t) == opp: yield (nr, nc) elif typ in 'BRQ': dirs = ([(1,1),(1,-1),(-1,1),(-1,-1)] if typ != 'R' else []) + \ ([(0,1),(0,-1),(1,0),(-1,0)] if typ != 'B' else []) for dr, dc in dirs: nr, nc = r+dr, c+dc while 0 <= nr < 8 and 0 <= nc < 8: t = self.at(nr, nc) if t: if self.col(t) == opp: yield (nr, nc) break yield (nr, nc); nr, nc = nr+dr, nc+dc elif typ == 'K': for dr in [-1,0,1]: for dc in [-1,0,1]: if dr or dc: nr, nc = r+dr, c+dc if 0 <= nr < 8 and 0 <= nc < 8: t = self.at(nr, nc) if not t or self.col(t) == opp: yield (nr, nc) elif typ == 'P': fwd = 1 if col == 'w' else -1; home = 1 if col == 'w' else 6 if not self.at(r+fwd, c): yield (r+fwd, c) if r == home and not self.at(r+2*fwd, c): yield (r+2*fwd, c) for dc in [-1, 1]: nr, nc = r+fwd, c+dc if 0 <= nc < 8: t = self.at(nr, nc) if (t and self.col(t) == opp) or (nr, nc) == self.ep: yield (nr, nc) def is_valid(self, r1: int, c1: int, r2: int, c2: int) -> bool: return (r2, c2) in self.piece_moves(r1, c1) def find_king(self, col: str) -> tuple[int,int]: for (r, c), p in self.board.items(): if p == f"{col}K": return (r, c) raise ValueError(f"No {col} king") def in_check(self, col: str) -> bool: kr, kc = self.find_king(col) opp = 'b' if col == 'w' else 'w' for (r, c), p in self.board.items(): if self.col(p) == opp and (kr, kc) in self.piece_moves(r, c): return True return False
Question 12 - Consistent Hash
Difficulty: 10 / 10
Approximate lines of code: 80 LoC
Tags: distributed-systems, data-structures
Description
Consistent hashing is the algorithm used by distributed systems like Cassandra, DynamoDB, and memcached to distribute data across nodes with minimal remapping when nodes join or leave. The key insight: when a node is added or removed, only K/n keys need to move (where K is total keys and n is number of nodes), compared to traditional hashing where nearly all keys would remap.
The core data structure is a “hash ring” - imagine a circle from 0 to 2^32. Nodes are placed at positions determined by hashing their IDs. To find which node owns a key, hash the key and walk clockwise until you hit a node. Internal state is a sorted list of (hash_position, node_id) pairs plus a dict mapping hash positions to node names for O(1) lookup after binary search.
Note: The sortedcontainers library (SortedList, SortedDict) is available and can simplify this problem. Using SortedList instead of bisect.insort + list gives O(log n) insertions/deletions instead of O(n).
Part A: Basic Ring
Problem: Part A
Implement add_node(node_id), remove_node(node_id), and get_node(key). Use a stable hash function (MD5 or SHA - not Python’s hash() which varies between runs). Store positions in a sorted list for binary search.
ring = ConsistentHashRing()
ring.add_node("server1")
ring.add_node("server2")

# Internal state after adding nodes:
# _sorted_keys = [1847293, 2938471]  # sorted hash positions
# _ring = {1847293: "server1", 2938471: "server2"}

ring.get_node("user:1001")  # -> "server2"
# hash("user:1001") = 2103847
# binary search finds first position >= 2103847
# returns "server2"

ring.get_node("session:xyz")  # -> "server1"
# hash("session:xyz") = 3999999
# > all positions, so wrap to first node "server1"
Part B: Virtual Nodes
Problem: Part B
With few physical nodes, keys distribute unevenly. Solution: give each physical node multiple “virtual node” positions on the ring (typically 100-200). Hash "server1#vn0", "server1#vn1", etc. to spread each server across the ring.
ring = ConsistentHashRing(num_virtual_nodes=100)
ring.add_node("server1")
ring.add_node("server2")

# Internal state now has 200 entries (100 per server):
# _sorted_keys = [12847, 38291, 49182, ...]  # 200 positions
# _ring = {12847: "server1", 38291: "server2", 49182: "server1", ...}

# get_node() works identically - lookup returns physical node name
ring.get_node("key123")  # -> "server1"
Part C: Efficient Lookup
Problem: Part C
Use bisect.bisect_right() for O(log n) lookup instead of linear scan. Handle the wrap-around case: when a key’s hash exceeds all node positions, wrap to index 0.
def get_node(self, key: str) -> str:
    hash_val = self._hash(key)
    idx = bisect.bisect_right(self._sorted_keys, hash_val)
    # Wrap around if past the end
    if idx == len(self._sorted_keys):
        idx = 0
    return self._ring[self._sorted_keys[idx]]
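As the note above suggests, sortedcontainers can absorb most of this bookkeeping. Here is a rough sketch of the same ring built on SortedList; it assumes the library is installed (as the note states) and is illustrative rather than one of the solutions below.

# Sketch of the ring using sortedcontainers.SortedList
import hashlib
from sortedcontainers import SortedList

class SortedListRing:
    def __init__(self) -> None:
        self._positions = SortedList()     # sorted hash positions, O(log n) add/remove
        self._owner: dict[int, str] = {}   # hash position -> node id

    def _hash(self, key: str) -> int:
        return int(hashlib.md5(key.encode()).hexdigest(), 16) % (2**32)

    def add_node(self, node_id: str) -> None:
        pos = self._hash(node_id)
        self._owner[pos] = node_id
        self._positions.add(pos)

    def remove_node(self, node_id: str) -> None:
        pos = self._hash(node_id)
        if self._owner.pop(pos, None) is not None:
            self._positions.remove(pos)

    def get_node(self, key: str) -> str | None:
        if not self._positions:
            return None
        idx = self._positions.bisect_right(self._hash(key))
        if idx == len(self._positions):
            idx = 0                        # wrap around the ring
        return self._owner[self._positions[idx]]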
Interview comments
Edge cases to probe:
What happens when you call get_node() on an empty ring?
What if you add the same node twice?
What if hash produces a collision with an existing virtual node?
How do you handle the wrap-around at the end of the ring?
Common mistakes:
Using bisect_left instead of bisect_right with wrong boundary handling
Forgetting wrap-around case (IndexError when hash > all positions)
Using Python’s hash() which is non-deterministic across runs
Not handling duplicate add_node calls gracefully
O(n) removal by iterating sorted list instead of tracking virtual node positions
Code solutions
Solution 1 uses a basic sorted list with bisect.insort for insertion and binary search for lookup. Solution 2 leverages SortedDict from sortedcontainers for cleaner code and adds migration tracking to report which key ranges move when nodes change. Solution 3 extends consistent hashing to support replication, returning N distinct physical nodes for fault-tolerant storage. Core techniques: binary search, hash functions (MD5/SHA), virtual nodes for load balancing.
Solution 1: Basic sorted list with bisect
Basic implementation using a sorted list with bisect.insort for insertion (O(log n) to find the slot, O(n) to shift elements) and binary search for lookup. Virtual nodes are generated with a consistent naming scheme (node#vnN). Clean dataclass-based design.
"""Consistent Hashing Ring - Solution 1: Basic ImplementationUses a sorted list and binary search for node lookup."""from dataclasses import dataclass, fieldfrom typing import Optionalimport hashlibimport bisect@dataclassclass ConsistentHashRing: """A consistent hashing ring with virtual nodes for load balancing.""" num_virtual_nodes: int = 100 _ring: dict[int, str] = field(default_factory=dict) _sorted_keys: list[int] = field(default_factory=list) _nodes: set[str] = field(default_factory=set) def _hash(self, key: str) -> int: """Generate a hash position on the ring (0 to 2^32 - 1).""" return int(hashlib.md5(key.encode()).hexdigest(), 16) % (2**32) def _virtual_node_key(self, node: str, replica: int) -> str: """Generate a unique key for a virtual node.""" return f"{node}#vn{replica}" def add_node(self, node: str) -> None: """Add a node with its virtual nodes to the ring.""" if node in self._nodes: return self._nodes.add(node) for i in range(self.num_virtual_nodes): vn_key = self._virtual_node_key(node, i) hash_val = self._hash(vn_key) self._ring[hash_val] = node bisect.insort(self._sorted_keys, hash_val) def remove_node(self, node: str) -> None: """Remove a node and all its virtual nodes from the ring.""" if node not in self._nodes: return self._nodes.discard(node) for i in range(self.num_virtual_nodes): vn_key = self._virtual_node_key(node, i) hash_val = self._hash(vn_key) if hash_val in self._ring: del self._ring[hash_val] self._sorted_keys.remove(hash_val) def get_node(self, key: str) -> Optional[str]: """Find the node responsible for a given key.""" if not self._ring: return None hash_val = self._hash(key) idx = bisect.bisect_right(self._sorted_keys, hash_val) # Wrap around to the first node if past the last if idx == len(self._sorted_keys): idx = 0 return self._ring[self._sorted_keys[idx]] def get_nodes(self) -> set[str]: """Return all physical nodes in the ring.""" return self._nodes.copy()
Solution 2: SortedDict with migration tracking
Uses SortedDict from sortedcontainers for cleaner code. Adds migration tracking: add_node() and remove_node() return lists of key ranges that moved between nodes, useful for triggering data rebalancing.
"""Consistent Hashing Ring - Solution 2: Using SortedDictLeverages sortedcontainers for cleaner code and better performance.Falls back to a manual implementation if sortedcontainers unavailable."""from dataclasses import dataclass, fieldfrom typing import Optionalimport hashlibtry: from sortedcontainers import SortedDict HAS_SORTED_CONTAINERS = Trueexcept ImportError: HAS_SORTED_CONTAINERS = False@dataclassclass ConsistentHashRing: """Consistent hashing using SortedDict for O(log n) operations.""" virtual_nodes: int = 150 _ring: "SortedDict[int, str]" = field(default_factory=lambda: SortedDict()) _physical_nodes: set[str] = field(default_factory=set) def _hash(self, key: str) -> int: """Hash using SHA-256 for better distribution than MD5.""" digest = hashlib.sha256(key.encode()).hexdigest() return int(digest[:16], 16) # Use first 64 bits def add_node(self, node: str) -> list[tuple[str, str, str]]: """Add node, return list of (key_range_start, old_node, new_node) migrations.""" if node in self._physical_nodes: return [] self._physical_nodes.add(node) migrations = [] for i in range(self.virtual_nodes): vnode_hash = self._hash(f"{node}:{i}") if vnode_hash in self._ring: continue # Hash collision, skip # Find what node previously owned this position idx = self._ring.bisect_right(vnode_hash) if idx < len(self._ring): old_owner = self._ring.values()[idx] migrations.append((str(vnode_hash), old_owner, node)) self._ring[vnode_hash] = node return migrations def remove_node(self, node: str) -> list[tuple[str, str, str]]: """Remove node, return migrations to successor nodes.""" if node not in self._physical_nodes: return [] self._physical_nodes.discard(node) migrations = [] keys_to_remove = [] for hash_val, owner in self._ring.items(): if owner == node: keys_to_remove.append(hash_val) for hash_val in keys_to_remove: del self._ring[hash_val] # Find new owner idx = self._ring.bisect_right(hash_val) if len(self._ring) > 0: new_owner = self._ring.values()[idx % len(self._ring)] migrations.append((str(hash_val), node, new_owner)) return migrations def get_node(self, key: str) -> Optional[str]: """Get the node responsible for this key.""" if not self._ring: return None key_hash = self._hash(key) idx = self._ring.bisect_right(key_hash) if idx == len(self._ring): idx = 0 return self._ring.values()[idx] def get_node_count(self, sample_keys: int = 10000) -> dict[str, int]: """Sample key distribution across nodes.""" counts: dict[str, int] = {node: 0 for node in self._physical_nodes} for i in range(sample_keys): node = self.get_node(f"sample_key_{i}") if node: counts[node] += 1 return counts
Solution 3: Replication support
Extends basic consistent hashing to support replication. get_nodes_for_key(key, count=N) returns N distinct physical nodes by walking clockwise and skipping duplicate physical nodes (from virtual nodes). Essential for fault-tolerant storage where you want multiple replicas.
"""Consistent Hashing Ring - Solution 3: With Replication SupportExtends basic consistent hashing to return N replicas for fault tolerance."""from dataclasses import dataclass, fieldfrom typing import Optionalimport hashlibimport bisect@dataclassclass ReplicatedHashRing: """Consistent hash ring that supports fetching multiple replica nodes.""" virtual_nodes: int = 100 default_replicas: int = 3 _ring: dict[int, str] = field(default_factory=dict) _sorted_hashes: list[int] = field(default_factory=list) _nodes: set[str] = field(default_factory=set) def _hash(self, key: str) -> int: """Compute hash using xxhash-style mixing for speed simulation.""" # Using SHA-1 truncated; in production use xxhash or similar return int(hashlib.sha1(key.encode()).hexdigest()[:8], 16) def add_node(self, node: str) -> None: """Add a physical node with virtual replicas.""" if node in self._nodes: return self._nodes.add(node) for i in range(self.virtual_nodes): h = self._hash(f"{node}:vn:{i}") if h not in self._ring: # Avoid collisions self._ring[h] = node bisect.insort(self._sorted_hashes, h) def remove_node(self, node: str) -> None: """Remove a node and its virtual nodes.""" if node not in self._nodes: return self._nodes.discard(node) to_remove = [h for h, n in self._ring.items() if n == node] for h in to_remove: del self._ring[h] self._sorted_hashes.remove(h) def get_node(self, key: str) -> Optional[str]: """Get single node for a key.""" nodes = self.get_nodes_for_key(key, count=1) return nodes[0] if nodes else None def get_nodes_for_key(self, key: str, count: Optional[int] = None) -> list[str]: """Get N distinct physical nodes for replication.""" if not self._ring: return [] count = count or self.default_replicas count = min(count, len(self._nodes)) # Can't exceed physical nodes key_hash = self._hash(key) start_idx = bisect.bisect_right(self._sorted_hashes, key_hash) result: list[str] = [] seen: set[str] = set() for i in range(len(self._sorted_hashes)): idx = (start_idx + i) % len(self._sorted_hashes) node = self._ring[self._sorted_hashes[idx]] if node not in seen: result.append(node) seen.add(node) if len(result) == count: break return result def get_ring_state(self) -> list[tuple[int, str]]: """Debug: return sorted (hash, node) pairs.""" return [(h, self._ring[h]) for h in self._sorted_hashes]
Question 13 - DNS Cache
Difficulty: 4 / 10
Approximate lines of code: 70 LoC
Tags: storage
Description
A DNS cache stores domain-to-IP mappings to avoid repeated network lookups. Each record has a TTL (time-to-live) after which it expires. The cache must handle three things: TTL-based expiration, capacity limits (evicting old entries when full), and CNAME chains (aliases that point to other domains which may themselves be aliases).
Internal state is typically a dict mapping domain names to cache entries. Each entry contains the IP address (or CNAME target), the expiration timestamp, and optionally the record type. For LRU eviction, use an OrderedDict where accessing an entry moves it to the end, and eviction removes from the front.
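The OrderedDict pattern referenced here is small enough to show up front; this is a generic LRU sketch, independent of the DNS specifics in the parts below.

# Generic LRU skeleton with OrderedDict: recent entries live at the end
from collections import OrderedDict

class LRUStore:
    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self._data: OrderedDict[str, str] = OrderedDict()

    def get(self, key: str) -> str | None:
        if key not in self._data:
            return None
        self._data.move_to_end(key)           # mark as most recently used
        return self._data[key]

    def put(self, key: str, value: str) -> None:
        if key in self._data:
            self._data.move_to_end(key)
        elif len(self._data) >= self.max_size:
            self._data.popitem(last=False)    # evict least recently used (front)
        self._data[key] = value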
Part A: TTL-Based Caching
Problem: Part A
Implement basic DNS caching with time-to-live expiration. resolve(domain) checks the cache first, falling back to a recursive lookup on miss or expiration.
cache = DNSCache(default_ttl=300)  # 5 minute TTL

# First lookup - cache miss, performs recursive resolution
ip = cache.resolve("example.com")
# cache: {"example.com": CacheEntry(ip="93.184.216.34", expires_at=now+300)}

# Second lookup - cache hit
ip = cache.resolve("example.com")  # Returns instantly from cache

# Manual record insertion with custom TTL
cache.add_record("internal.corp", "10.0.0.1", ttl=60)

# After TTL expires, entry is stale
time.sleep(301)
cache.resolve("example.com")  # Cache miss - expired entry removed, fresh lookup
Part B: LRU Eviction
Problem: Part B
Add a maximum cache size. When inserting into a full cache, evict the least recently used entry. Accessing an entry (even a cache hit) should update its recency.
cache = LRUDNSCache(max_size=3, default_ttl=300)
cache.set_record("a.com", "1.1.1.1")
cache.set_record("b.com", "2.2.2.2")
cache.set_record("c.com", "3.3.3.3")
# Cache order (LRU to MRU): [a.com, b.com, c.com]

cache.resolve("a.com")  # Access a.com, moves to MRU position
# Cache order: [b.com, c.com, a.com]

cache.set_record("d.com", "4.4.4.4")  # Cache full, evict LRU (b.com)
# Cache order: [c.com, a.com, d.com]
# b.com is evicted

assert "b.com" not in cache  # Evicted
assert "a.com" in cache      # Kept (was recently accessed)
Part C: CNAME Chain Resolution
Problem: Part C
Handle CNAME (canonical name) records that alias one domain to another. When resolving, follow the chain until you reach an A record. Protect against infinite loops.
# DNS data:
# www.example.com  -> CNAME -> example.com
# blog.example.com -> CNAME -> www.example.com
# example.com      -> A     -> 93.184.216.34

cache = HierarchicalDNSCache(max_cname_depth=10)
ip = cache.resolve("blog.example.com")
# 1. Lookup blog.example.com -> CNAME www.example.com
# 2. Lookup www.example.com  -> CNAME example.com
# 3. Lookup example.com      -> A 93.184.216.34
# Returns: 93.184.216.34

# All records are cached:
# (blog.example.com, CNAME) -> www.example.com
# (www.example.com, CNAME)  -> example.com
# (example.com, A)          -> 93.184.216.34

# Loop protection:
# evil.com -> CNAME -> evil.com (would loop forever)
cache.resolve("evil.com")  # Returns None after max_cname_depth iterations
Interview comments
Edge cases to probe:
What happens when resolving a domain that doesn’t exist?
How do you handle CNAME loops (A → B → A)?
What if a CNAME points to another CNAME with a shorter TTL?
How do you evict entries - lazily on access or proactively?
Common mistakes:
Not removing expired entries from cache (just returning None isn’t enough)
Infinite loop on CNAME chains without depth limiting
Using a system clock that can jump (use monotonic time for TTL calculations; see the sketch after this list)
Storing CNAME target IP instead of following the chain
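On the monotonic-time point above: the fix is to record expiry against time.monotonic() rather than time.time(). A minimal sketch; note the solutions below use wall-clock time for simplicity.

# TTL bookkeeping with a monotonic clock (immune to wall-clock jumps)
import time
from dataclasses import dataclass

@dataclass
class Entry:
    ip_address: str
    expires_at: float  # a time.monotonic() value, not wall-clock time

def make_entry(ip: str, ttl_seconds: float) -> Entry:
    return Entry(ip_address=ip, expires_at=time.monotonic() + ttl_seconds)

def is_expired(entry: Entry) -> bool:
    return time.monotonic() > entry.expires_at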
Code solutions
Solution 1 is a simple TTL-based cache using dict storage with lazy expiration (entries removed on access after expiry). Solution 2 adds LRU eviction with a max capacity using OrderedDict for O(1) LRU operations via move_to_end() and popitem(). Solution 3 handles hierarchical DNS with CNAME chain resolution, using (domain, record_type) tuple keys and depth limiting to prevent infinite loops. These vary in their eviction strategy and DNS complexity: Solution 1 handles basic TTL, Solution 2 adds capacity management, and Solution 3 models realistic DNS behavior. Core techniques: TTL-based expiration, OrderedDict for LRU, CNAME chain following with depth limiting, lazy vs proactive eviction.
Solution 1: Simple TTL-based Cache
Simple TTL-based cache with dict storage. Stores expiration timestamp directly in each entry. Lazy expiration - entries are removed when accessed after expiry. Includes manual evict_expired() for batch cleanup.
"""DNS Cache Solution 1: Simple TTL-based cache with dict storage.Focuses on correctness and simplicity."""from dataclasses import dataclass, fieldfrom typing import Optionalimport time@dataclassclass CacheEntry: ip_address: str expires_at: float@dataclassclass DNSCache: cache: dict[str, CacheEntry] = field(default_factory=dict) default_ttl: int = 300 # 5 minutes def resolve(self, domain: str) -> Optional[str]: """Resolve domain to IP, checking cache first.""" entry = self.cache.get(domain) if entry and entry.expires_at > time.time(): return entry.ip_address # Cache miss or expired - do recursive resolution if entry: del self.cache[domain] ip = self._recursive_resolve(domain) if ip: self.cache[domain] = CacheEntry( ip_address=ip, expires_at=time.time() + self.default_ttl ) return ip def _recursive_resolve(self, domain: str) -> Optional[str]: """Simulate recursive DNS resolution.""" # In real implementation, this would query DNS servers fake_dns = { "example.com": "93.184.216.34", "google.com": "142.250.80.46", "github.com": "140.82.114.4", } return fake_dns.get(domain) def add_record(self, domain: str, ip: str, ttl: Optional[int] = None) -> None: """Manually add a DNS record to cache.""" self.cache[domain] = CacheEntry( ip_address=ip, expires_at=time.time() + (ttl or self.default_ttl) ) def evict_expired(self) -> int: """Remove all expired entries. Returns count of evicted entries.""" now = time.time() expired = [d for d, e in self.cache.items() if e.expires_at <= now] for domain in expired: del self.cache[domain] return len(expired)
Solution 2: LRU Eviction with Max Capacity
LRU eviction with max capacity using OrderedDict. move_to_end() on access maintains LRU ordering. popitem(last=False) evicts the oldest entry. Combines TTL expiration with capacity-based eviction.
"""DNS Cache Solution 2: LRU eviction with max capacity.Uses OrderedDict for O(1) LRU operations."""from dataclasses import dataclass, fieldfrom collections import OrderedDictfrom typing import Optionalimport time@dataclassclass DNSRecord: ip_address: str ttl: int created_at: float def is_expired(self) -> bool: return time.time() > self.created_at + self.ttl@dataclassclass LRUDNSCache: max_size: int = 1000 default_ttl: int = 300 _cache: OrderedDict[str, DNSRecord] = field(default_factory=OrderedDict) def resolve(self, domain: str) -> Optional[str]: """Resolve with LRU tracking.""" if domain in self._cache: record = self._cache[domain] if record.is_expired(): del self._cache[domain] else: # Move to end (most recently used) self._cache.move_to_end(domain) return record.ip_address # Cache miss - resolve and cache ip = self._recursive_resolve(domain) if ip: self._add_to_cache(domain, ip, self.default_ttl) return ip def _add_to_cache(self, domain: str, ip: str, ttl: int) -> None: """Add entry, evicting LRU if at capacity.""" if domain in self._cache: del self._cache[domain] elif len(self._cache) >= self.max_size: self._cache.popitem(last=False) # Remove oldest (LRU) self._cache[domain] = DNSRecord( ip_address=ip, ttl=ttl, created_at=time.time() ) def _recursive_resolve(self, domain: str) -> Optional[str]: """Simulate DNS lookup.""" records = { "example.com": "93.184.216.34", "google.com": "142.250.80.46", "github.com": "140.82.114.4", } return records.get(domain) def set_record(self, domain: str, ip: str, ttl: Optional[int] = None) -> None: """Manually set a DNS record.""" self._add_to_cache(domain, ip, ttl or self.default_ttl) def size(self) -> int: return len(self._cache)
Solution 3: Hierarchical DNS with CNAME Chains
Hierarchical DNS with CNAME chain resolution. Cache keys are (domain, record_type) tuples to store both A and CNAME records. Recursive resolution with depth limiting prevents infinite loops. Follows CNAME chains until reaching an A record.
"""DNS Cache Solution 3: Hierarchical DNS with multiple record types.More realistic simulation of DNS resolution chain."""from dataclasses import dataclass, fieldfrom typing import Optionalfrom enum import Enumimport timeclass RecordType(Enum): A = "A" # IPv4 address CNAME = "CNAME" # Canonical name (alias)@dataclassclass DNSRecord: record_type: RecordType value: str ttl: int expires_at: float@dataclassclass HierarchicalDNSCache: cache: dict[tuple[str, RecordType], DNSRecord] = field(default_factory=dict) default_ttl: int = 300 max_cname_depth: int = 10 # Prevent infinite loops def resolve(self, domain: str) -> Optional[str]: """Resolve domain, following CNAME chains.""" return self._resolve_with_depth(domain, 0) def _resolve_with_depth(self, domain: str, depth: int) -> Optional[str]: if depth > self.max_cname_depth: return None # CNAME loop protection # Check for A record first a_record = self._get_cached(domain, RecordType.A) if a_record: return a_record.value # Check for CNAME and follow it cname = self._get_cached(domain, RecordType.CNAME) if cname: return self._resolve_with_depth(cname.value, depth + 1) # Cache miss - do recursive lookup record = self._recursive_lookup(domain) if not record: return None self._cache_record(domain, record) if record.record_type == RecordType.A: return record.value else: # CNAME - follow the chain return self._resolve_with_depth(record.value, depth + 1) def _get_cached(self, domain: str, rtype: RecordType) -> Optional[DNSRecord]: key = (domain, rtype) record = self.cache.get(key) if record and record.expires_at > time.time(): return record if record: del self.cache[key] return None def _cache_record(self, domain: str, record: DNSRecord) -> None: self.cache[(domain, record.record_type)] = record def _recursive_lookup(self, domain: str) -> Optional[DNSRecord]: """Simulate authoritative DNS response.""" dns_data = { "example.com": (RecordType.A, "93.184.216.34"), "www.example.com": (RecordType.CNAME, "example.com"), "blog.example.com": (RecordType.CNAME, "www.example.com"), "google.com": (RecordType.A, "142.250.80.46"), } if domain not in dns_data: return None rtype, value = dns_data[domain] return DNSRecord( record_type=rtype, value=value, ttl=self.default_ttl, expires_at=time.time() + self.default_ttl )
Question 14 - Pub/Sub Message Broker
Difficulty: 7 / 10
Approximate lines of code: 90 LoC
Tags: distributed-systems
Description
A publish-subscribe message broker decouples message producers from consumers. Publishers send messages to topics without knowing who will receive them; subscribers register interest in topics and receive matching messages. This pattern is fundamental to event-driven architectures, microservices communication, and real-time systems.
There are two main delivery models: push (broker calls subscriber callbacks immediately) and pull (subscribers request messages on demand). Each message gets a sequence number for ordering. For reliability, the broker may retain messages so disconnected subscribers can catch up. Consumer groups add load balancing - within a group, each message goes to only one member.
Part A: Basic Pub/Sub
Problem: Part A
Implement create_topic(), subscribe(topic, subscriber_id, callback), and publish(topic, payload). Messages are pushed to subscriber callbacks immediately.
Part B: Message Retention and Pull
Problem: Part B
Add sequence numbers and message retention. Subscribers can reconnect and catch up from where they left off (pull model).
broker = PubSubBroker()broker.create_topic("events")broker.subscribe("events", "sub1")broker.publish("events", "msg1") # seq=1broker.publish("events", "msg2") # seq=2broker.publish("events", "msg3") # seq=3# Pull model - subscriber controls when to receivemsgs = broker.pull("events", "sub1", max_messages=2)# Returns [msg1, msg2], updates sub1's last_read to 2print(broker.get_pending_count("events", "sub1")) # 1 (msg3 pending)# After disconnect/reconnect, subscriber catches up:broker.disconnect("sub1")broker.publish("events", "msg4") # Subscriber is offlinebroker.reconnect("sub1")msgs = broker.pull("events", "sub1") # Gets msg3 and msg4
Part C: Consumer Groups
Problem: Part C
Within a consumer group, each message is delivered to exactly one member (round-robin load balancing). Individual subscribers still receive all messages.
broker = PubSubBroker()broker.create_topic("orders")# Individual subscriber gets ALL messagesmonitor_msgs = []broker.subscribe("monitor", "orders", lambda m: monitor_msgs.append(m))# Consumer group - messages load balancedworker1_msgs = []worker2_msgs = []broker.join_group("worker1", "orders", "processors", lambda m: worker1_msgs.append(m))broker.join_group("worker2", "orders", "processors", lambda m: worker2_msgs.append(m))for i in range(4): broker.publish("orders", f"order_{i}")print(len(monitor_msgs)) # 4 - monitor gets allprint(len(worker1_msgs) + len(worker2_msgs)) # 4 - split between workers# Each worker gets ~2 messages (round-robin)
Interview comments
Edge cases to probe:
What happens if you publish to a non-existent topic?
What if a subscriber callback raises an exception?
How do you handle unsubscribe during iteration?
What’s the memory impact of unbounded message retention?
Common mistakes:
Mutating the subscriber collection while iterating over it (the Python analogue of Java's ConcurrentModificationException: "dictionary changed size during iteration")
Not handling disconnected subscribers (should skip, not crash)
Memory leaks from infinite message retention (need TTL or compaction)
Sequence numbers resetting on restart (need persistence for durability)
Code solutions
Solutions Overview
Solution 1 is a simple push model that immediately invokes subscriber callbacks on publish. Solution 2 is a pull model with message retention and read position tracking per subscriber. Solution 3 is a hybrid push/pull with consumer groups for load-balanced delivery. These vary in their delivery model (push vs pull) and subscriber grouping. Core techniques: callback dispatch, sequence number tracking, round-robin load balancing.
Solution 1: Simple push model
Simple push model with callbacks. Messages stored per topic with sequence numbers. Subscribers can connect/disconnect. Immediately invokes callbacks on publish.
"""Pub/Sub Message Broker - Solution 1: Simple In-Memory with Push ModelBasic implementation using callbacks for message delivery.Subscribers register callbacks that are invoked when messages are published."""from dataclasses import dataclass, fieldfrom typing import Callable, Dict, List, Setfrom collections import defaultdict@dataclassclass Message: topic: str payload: str sequence: int = 0@dataclassclass Subscriber: id: str callback: Callable[[Message], None] connected: bool = True@dataclassclass PubSubBroker: _topics: Dict[str, List[Message]] = field(default_factory=lambda: defaultdict(list)) _subscribers: Dict[str, Dict[str, Subscriber]] = field(default_factory=lambda: defaultdict(dict)) _sequence: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) def create_topic(self, topic: str) -> None: if topic not in self._topics: self._topics[topic] = [] self._sequence[topic] = 0 def subscribe(self, topic: str, subscriber_id: str, callback: Callable[[Message], None]) -> bool: if topic not in self._topics: return False self._subscribers[topic][subscriber_id] = Subscriber(id=subscriber_id, callback=callback) return True def unsubscribe(self, topic: str, subscriber_id: str) -> bool: if topic in self._subscribers and subscriber_id in self._subscribers[topic]: del self._subscribers[topic][subscriber_id] return True return False def disconnect(self, subscriber_id: str) -> None: for topic_subs in self._subscribers.values(): if subscriber_id in topic_subs: topic_subs[subscriber_id].connected = False def reconnect(self, subscriber_id: str) -> None: for topic_subs in self._subscribers.values(): if subscriber_id in topic_subs: topic_subs[subscriber_id].connected = True def publish(self, topic: str, payload: str) -> int: if topic not in self._topics: return -1 self._sequence[topic] += 1 msg = Message(topic=topic, payload=payload, sequence=self._sequence[topic]) self._topics[topic].append(msg) # Push to connected subscribers for sub in self._subscribers[topic].values(): if sub.connected: sub.callback(msg) return msg.sequence
Solution 2: Pull model with retention
Pull model with message retention. Tracks read position per subscriber. Supports catching up on missed messages after disconnect/reconnect. Includes pending count tracking.
"""Pub/Sub Message Broker - Solution 2: Pull Model with Message RetentionSubscribers pull messages on demand. Tracks read position per subscriber.Supports catching up on missed messages after reconnect."""from dataclasses import dataclass, fieldfrom typing import Dict, List, Optionalfrom collections import defaultdict@dataclassclass Message: topic: str payload: str sequence: int@dataclassclass SubscriberState: id: str last_read: int = 0 # Last sequence number read connected: bool = True@dataclassclass PubSubBroker: _messages: Dict[str, List[Message]] = field(default_factory=lambda: defaultdict(list)) _subscribers: Dict[str, Dict[str, SubscriberState]] = field(default_factory=lambda: defaultdict(dict)) _sequence: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) def create_topic(self, topic: str) -> None: if topic not in self._messages: self._messages[topic] = [] def subscribe(self, topic: str, subscriber_id: str, from_beginning: bool = False) -> bool: if topic not in self._messages: return False start_pos = 0 if from_beginning else self._sequence[topic] self._subscribers[topic][subscriber_id] = SubscriberState(id=subscriber_id, last_read=start_pos) return True def unsubscribe(self, topic: str, subscriber_id: str) -> bool: if topic in self._subscribers and subscriber_id in self._subscribers[topic]: del self._subscribers[topic][subscriber_id] return True return False def disconnect(self, subscriber_id: str) -> None: for topic_subs in self._subscribers.values(): if subscriber_id in topic_subs: topic_subs[subscriber_id].connected = False def reconnect(self, subscriber_id: str) -> None: for topic_subs in self._subscribers.values(): if subscriber_id in topic_subs: topic_subs[subscriber_id].connected = True def publish(self, topic: str, payload: str) -> int: if topic not in self._messages: return -1 self._sequence[topic] += 1 msg = Message(topic=topic, payload=payload, sequence=self._sequence[topic]) self._messages[topic].append(msg) return msg.sequence def pull(self, topic: str, subscriber_id: str, max_messages: int = 10) -> List[Message]: if topic not in self._subscribers or subscriber_id not in self._subscribers[topic]: return [] state = self._subscribers[topic][subscriber_id] if not state.connected: return [] # Find messages after last_read result = [] for msg in self._messages[topic]: if msg.sequence > state.last_read: result.append(msg) if len(result) >= max_messages: break if result: state.last_read = result[-1].sequence return result def get_pending_count(self, topic: str, subscriber_id: str) -> int: if topic not in self._subscribers or subscriber_id not in self._subscribers[topic]: return 0 state = self._subscribers[topic][subscriber_id] return self._sequence[topic] - state.last_read
Solution 3: Hybrid with consumer groups
Hybrid push/pull with consumer groups. Individual subscribers get all messages; consumer group members get round-robin distribution. Includes message acknowledgment for reliable delivery.
"""Pub/Sub Message Broker - Solution 3: Hybrid Push/Pull with Consumer GroupsFeatures:- Consumer groups for load balancing (each message delivered to one group member)- Individual subscribers get all messages- Message acknowledgment for reliable delivery"""from dataclasses import dataclass, fieldfrom typing import Callable, Dict, List, Optional, Setfrom collections import defaultdictimport itertools@dataclassclass Message: topic: str payload: str sequence: int@dataclassclass ConsumerGroup: name: str members: List[str] = field(default_factory=list) pending: Dict[int, str] = field(default_factory=dict) # seq -> assigned_member last_delivered: int = 0 _round_robin: int = 0 def next_member(self) -> Optional[str]: connected = [m for m in self.members] if not connected: return None self._round_robin = (self._round_robin + 1) % len(connected) return connected[self._round_robin]@dataclassclass PubSubBroker: _messages: Dict[str, List[Message]] = field(default_factory=lambda: defaultdict(list)) _individual_subs: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) _consumer_groups: Dict[str, Dict[str, ConsumerGroup]] = field(default_factory=lambda: defaultdict(dict)) _callbacks: Dict[str, Callable[[Message], None]] = field(default_factory=dict) _sequence: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) def create_topic(self, topic: str) -> None: if topic not in self._messages: self._messages[topic] = [] def subscribe(self, subscriber_id: str, topic: str, callback: Callable[[Message], None]) -> bool: if topic not in self._messages: return False self._individual_subs[topic].add(subscriber_id) self._callbacks[subscriber_id] = callback return True def join_group(self, subscriber_id: str, topic: str, group_name: str, callback: Callable[[Message], None]) -> bool: if topic not in self._messages: return False if group_name not in self._consumer_groups[topic]: self._consumer_groups[topic][group_name] = ConsumerGroup(name=group_name) group = self._consumer_groups[topic][group_name] if subscriber_id not in group.members: group.members.append(subscriber_id) self._callbacks[subscriber_id] = callback return True def leave_group(self, subscriber_id: str, topic: str, group_name: str) -> bool: if group_name in self._consumer_groups[topic]: group = self._consumer_groups[topic][group_name] if subscriber_id in group.members: group.members.remove(subscriber_id) return True return False def publish(self, topic: str, payload: str) -> int: if topic not in self._messages: return -1 self._sequence[topic] += 1 msg = Message(topic=topic, payload=payload, sequence=self._sequence[topic]) self._messages[topic].append(msg) # Push to individual subscribers for sub_id in self._individual_subs[topic]: if sub_id in self._callbacks: self._callbacks[sub_id](msg) # Push to one member per consumer group (round-robin) for group in self._consumer_groups[topic].values(): member = group.next_member() if member and member in self._callbacks: group.pending[msg.sequence] = member self._callbacks[member](msg) return msg.sequence def ack(self, topic: str, group_name: str, sequence: int) -> bool: if group_name in self._consumer_groups[topic]: group = self._consumer_groups[topic][group_name] if sequence in group.pending: del group.pending[sequence] return True return False
Question 15 - Reader-Writer Lock
Difficulty: 9 / 10
Approximate lines of code: 60 LoC
Tags: concurrency
Description
Concurrency Cheat Sheet
This problem involves concurrency. Here is a cheat sheet of relevant Python syntax. You may not need all of these; the list is neither exhaustive nor a hint about which primitives the problem requires.
# threading.Thread
t = threading.Thread(target=func, args=(arg1, arg2))
t.start()
t.join()

# threading.Lock
lock = threading.Lock()
lock.acquire()
lock.release()
with lock:
    ...  # critical section

# threading.Condition
cond = threading.Condition(lock)  # or threading.Condition()
with cond:
    cond.wait()        # releases lock, waits, re-acquires
    cond.notify()      # wake one waiter
    cond.notify_all()  # wake all waiters

# threading.Semaphore
sem = threading.Semaphore(value=1)
sem.acquire()
sem.release()

# threading.Event
event = threading.Event()
event.set()     # set flag to True
event.clear()   # set flag to False
event.wait()    # block until flag is True
event.is_set()  # check flag

# threading.RLock (reentrant)
rlock = threading.RLock()
rlock.acquire()  # can acquire multiple times from same thread
rlock.release()  # must release same number of times
A reader-writer lock (RWLock) allows concurrent read access but exclusive write access. Multiple readers can hold the lock simultaneously since reads don’t conflict, but writers need exclusive access to prevent data races. This is useful for shared data structures where reads vastly outnumber writes, like caches or configuration stores.
The basic invariants are: (1) if a writer holds the lock, no readers or other writers can hold it, and (2) if readers hold the lock, any number of additional readers can acquire it, but writers must wait. The tricky part is preventing writer starvation: without care, a steady stream of readers can indefinitely postpone waiting writers.
Part A: Basic Lock
Problem: Part A
Implement acquire_read(), release_read(), acquire_write(), and release_write(). Use a lock and condition variable(s) to coordinate. Track the number of active readers and whether a writer holds the lock.
lock = ReaderWriterLock()

# Thread 1 (Reader)   Thread 2 (Reader)   Thread 3 (Writer)
lock.acquire_read()   # Thread 1: _readers = 1
lock.acquire_read()   # Thread 2: both readers active, _readers = 2
lock.acquire_write()  # Thread 3: BLOCKS, waiting for _readers == 0
lock.release_read()   # Thread 1: _readers = 1, writer still blocked
lock.release_read()   # Thread 2: _readers = 0, writer unblocked!
                      # Thread 3: _writer = True, exclusive access

# Internal state during writer hold:
# _readers = 0, _writer = True
# Any acquire_read() or acquire_write() blocks
Part B: Starvation Prevention
Problem: Part B
The basic implementation has a problem: if new readers keep arriving, a waiting writer never gets in. Fix this by tracking waiting writers and blocking new readers when a writer is waiting.
lock = ReaderWriterLock()  # With writer preference
# Thread 1-5 are readers, Thread 6 is a writer

# Time 0: Readers 1,2,3 acquire read locks
#   _readers = 3, _waiting_writers = 0
# Time 1: Writer 6 calls acquire_write()
#   _waiting_writers = 1, writer blocks waiting for _readers == 0
# Time 2: Reader 4 calls acquire_read()
#   With writer preference: blocks because _waiting_writers > 0
#   Without: would acquire immediately, potentially starving writer
# Time 3: Readers 1,2,3 release
#   _readers = 0, writer 6 acquires
#   _waiting_writers = 0, _writer = True
# Time 4: Writer 6 releases
#   Reader 4 can now acquire

# Key invariant for writer preference:
# acquire_read blocks while: _writer OR _waiting_writers > 0
Part C: Upgradable Lock
Problem: Part C
Add an “upgradable read” mode: a reader that can later upgrade to a writer without releasing the lock. Only one thread can hold an upgradable lock at a time (if two could, they’d deadlock trying to upgrade simultaneously). The upgrade waits for other readers to release.
lock = UpgradableRWLock()

# Thread 1 (Upgradable)       Thread 2 (Regular Reader)
lock.acquire_upgradable()   # Thread 1: counts as a reader, _readers = 1
                            #           _upgradable_held = True
lock.acquire_read()         # Thread 2: _readers = 2

# Thread 1 decides to write:
lock.upgrade()              # Thread 1: sets _upgrade_waiting = True (blocks new readers)
                            #           decrements _readers to 1 (self)
                            #           waits for _readers == 0
lock.release_read()         # Thread 2: _readers = 0

# Thread 1 continues:
# _writer = True, _upgrade_waiting = False
# Now has exclusive write access
lock.release_write()        # Not release_upgradable!
# _writer = False, _upgradable_held = False

# Key: only ONE upgradable lock at a time
lock2.acquire_upgradable()  # Thread 3 tries
lock2.acquire_upgradable()  # Thread 4 blocks until Thread 3 releases
Interview comments
Edge cases to probe:
What happens if a thread releases a lock it doesn’t hold?
How do you prevent new readers from starving waiting writers?
Why can only one thread hold an upgradable lock?
What if a thread tries to upgrade without holding an upgradable lock?
Common mistakes:
Race condition in first-reader/last-reader logic (count check and lock acquisition not atomic)
Using notify() instead of notify_all() when releasing write lock (multiple readers waiting)
Deadlock in upgrade: still counting self as reader while waiting for readers == 0
Writer starvation: not blocking new readers when writers are waiting
Forgetting to decrement reader count during upgrade before waiting
Code solutions
Solution 1 uses condition variables with separate _can_read and _can_write conditions, tracking _waiting_writers to implement writer preference. Solution 2 uses a semaphore-based turnstile pattern where writers hold the turnstile to block new readers. Solution 3 adds upgradable read locks that can be promoted to write locks without releasing. The key difference is the synchronization mechanism and whether they support lock upgrades.
Solution 1: Condition variable based
Condition variable based implementation with separate conditions for readers and writers. Uses _waiting_writers count to implement writer preference and prevent starvation.
"""Reader-Writer Lock - Solution 1: Condition Variable BasedUses a condition variable with reader/writer counts and a waiting writer flagto prevent writer starvation."""from __future__ import annotationsimport threadingfrom dataclasses import dataclass, fieldfrom typing import Optional@dataclassclass ReaderWriterLock: """Reader-writer lock preventing writer starvation via waiting_writers count.""" _lock: threading.Lock = field(default_factory=threading.Lock) _can_read: threading.Condition = field(init=False) _can_write: threading.Condition = field(init=False) _readers: int = 0 _writers: int = 0 _waiting_writers: int = 0 def __post_init__(self) -> None: self._can_read = threading.Condition(self._lock) self._can_write = threading.Condition(self._lock) def acquire_read(self) -> None: """Acquire read lock. Blocks if writer holds lock or writers are waiting.""" with self._can_read: while self._writers > 0 or self._waiting_writers > 0: self._can_read.wait() self._readers += 1 def release_read(self) -> None: """Release read lock. Notifies waiting writers if last reader.""" with self._can_write: self._readers -= 1 if self._readers == 0: self._can_write.notify() def acquire_write(self) -> None: """Acquire write lock. Blocks until no readers or writers.""" with self._can_write: self._waiting_writers += 1 while self._readers > 0 or self._writers > 0: self._can_write.wait() self._waiting_writers -= 1 self._writers += 1 def release_write(self) -> None: """Release write lock. Notifies all waiting readers and one writer.""" with self._lock: self._writers -= 1 self._can_read.notify_all() self._can_write.notify()
Solution 2: Semaphore based with fair queuing
Semaphore-based implementation using a turnstile pattern. Writers hold the turnstile to block new readers. Uses room_empty semaphore for writer exclusivity.
"""Reader-Writer Lock - Solution 2: Semaphore Based with Fair QueuingUses a turnstile semaphore to ensure fairness between readers and writers.Writers block the turnstile, preventing new readers from cutting ahead."""from __future__ import annotationsimport threadingfrom dataclasses import dataclass, field@dataclassclass FairReaderWriterLock: """Fair reader-writer lock using semaphores with turnstile pattern.""" _turnstile: threading.Semaphore = field(default_factory=lambda: threading.Semaphore(1)) _room_empty: threading.Semaphore = field(default_factory=lambda: threading.Semaphore(1)) _reader_lock: threading.Lock = field(default_factory=threading.Lock) _reader_count: int = 0 def acquire_read(self) -> None: """Acquire read lock. Goes through turnstile for fairness.""" self._turnstile.acquire() self._turnstile.release() with self._reader_lock: self._reader_count += 1 if self._reader_count == 1: self._room_empty.acquire() def release_read(self) -> None: """Release read lock. Last reader signals room is empty.""" with self._reader_lock: self._reader_count -= 1 if self._reader_count == 0: self._room_empty.release() def acquire_write(self) -> None: """Acquire write lock. Holds turnstile to block new readers.""" self._turnstile.acquire() self._room_empty.acquire() def release_write(self) -> None: """Release write lock.""" self._room_empty.release() self._turnstile.release()
Solution 3: With upgradable read lock
Full implementation with upgradable read locks. Only one thread can hold an upgradable lock. Upgrade operation decrements reader count, sets waiting flag, then waits for other readers.
"""Reader-Writer Lock - Solution 3: With Upgradable Read LockSupports upgrading a read lock to a write lock without releasing.Only one thread can hold an upgradable lock at a time."""from __future__ import annotationsimport threadingfrom dataclasses import dataclass, field@dataclassclass UpgradableRWLock: """Reader-writer lock supporting upgradable read locks.""" _lock: threading.Lock = field(default_factory=threading.Lock) _cond: threading.Condition = field(init=False) _readers: int = 0 _writer: bool = False _upgradable_held: bool = False _upgrade_waiting: bool = False def __post_init__(self) -> None: self._cond = threading.Condition(self._lock) def acquire_read(self) -> None: """Acquire shared read lock.""" with self._cond: while self._writer or self._upgrade_waiting: self._cond.wait() self._readers += 1 def release_read(self) -> None: """Release shared read lock.""" with self._cond: self._readers -= 1 if self._readers == 0: self._cond.notify_all() def release_upgradable(self) -> None: """Release upgradable read lock without upgrading.""" with self._cond: self._readers -= 1 self._upgradable_held = False if self._readers == 0: self._cond.notify_all() def acquire_upgradable(self) -> None: """Acquire upgradable read lock. Only one thread can hold this.""" with self._cond: while self._writer or self._upgradable_held: self._cond.wait() self._upgradable_held = True self._readers += 1 def upgrade(self) -> None: """Upgrade from upgradable read lock to write lock.""" with self._cond: self._upgrade_waiting = True self._readers -= 1 while self._readers > 0: self._cond.wait() self._upgrade_waiting = False self._writer = True def release_write(self) -> None: """Release write lock (after upgrade or direct acquire).""" with self._cond: self._writer = False self._upgradable_held = False self._cond.notify_all() def acquire_write(self) -> None: """Acquire exclusive write lock.""" with self._cond: while self._writer or self._readers > 0 or self._upgradable_held: self._cond.wait() self._writer = True
Question 16 - Meeting Scheduler
Difficulty: 6 / 10
Approximate lines of code: 100 LoC
Tags: scheduling, algorithms
Description
A meeting scheduler finds common free time slots across multiple attendees in different timezones. The core algorithm: (1) normalize all busy intervals to a common timezone (UTC), (2) merge overlapping busy intervals, (3) find gaps between busy periods that meet the minimum duration requirement. The key data structure is a list of intervals that gets merged via sorting and linear scan.
Example: Alice (NYC) is busy 9-10am local, Bob (LA) is busy 7-8am local. In UTC, Alice is busy 14:00-15:00, Bob is busy 15:00-16:00. After merging, the combined busy period is 14:00-16:00 UTC.
Part A: Merge Busy Intervals and Find Gaps
Problem: Part A
Given multiple attendees with busy time slots, find all common free slots of at least N minutes.
alice = Attendee(
    name="Alice",
    timezone="America/New_York",
    busy_slots=[TimeSlot(datetime(2024, 1, 15, 9, 0), datetime(2024, 1, 15, 10, 0))],
)
bob = Attendee(
    name="Bob",
    timezone="America/Los_Angeles",
    busy_slots=[TimeSlot(datetime(2024, 1, 15, 14, 0), datetime(2024, 1, 15, 15, 0))],
)

# Search window: 14:00-20:00 UTC
slots = find_free_slots(
    [alice, bob],
    search_start=datetime(2024, 1, 15, 14, 0),  # UTC
    search_end=datetime(2024, 1, 15, 20, 0),    # UTC
    duration_minutes=30,
)

# Internal processing:
# 1. Normalize to UTC:
#    Alice 9am NYC = 14:00 UTC (busy 14:00-15:00)
#    Bob 2pm LA = 22:00 UTC (outside search window, ignored)
# 2. Merge overlapping: [(14:00, 15:00)]
# 3. Find gaps >= 30 min: [(15:00, 20:00)]
Interval merge algorithm:
def merge_intervals(slots):
    sorted_slots = sorted(slots, key=lambda s: s.start)
    merged = [sorted_slots[0]]
    for slot in sorted_slots[1:]:
        if slot.start <= merged[-1].end:
            merged[-1].end = max(merged[-1].end, slot.end)
        else:
            merged.append(slot)
    return merged
Part B: Timezone Handling
Problem: Part B
Properly convert between local times and UTC. Each attendee specifies their timezone; all calculations happen in UTC.
def normalize_to_utc(slot: TimeSlot, tz: str) -> TimeSlot:
    local_tz = ZoneInfo(tz)
    utc = ZoneInfo("UTC")
    start_utc = slot.start.replace(tzinfo=local_tz).astimezone(utc)
    end_utc = slot.end.replace(tzinfo=local_tz).astimezone(utc)
    return TimeSlot(start_utc.replace(tzinfo=None), end_utc.replace(tzinfo=None))

# NYC (EST = UTC-5): 9am local -> 14:00 UTC
# LA (PST = UTC-8): 9am local -> 17:00 UTC
# London (GMT = UTC+0): 9am local -> 09:00 UTC
Key: Use zoneinfo.ZoneInfo (Python 3.9+) or pytz. Never mix naive and aware datetimes.
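To make the naive/aware point concrete, here's a tiny sketch (standard library only, dates picked arbitrarily) showing how the UTC offset shifts across a DST boundary and what happens if you compare naive and aware datetimes:
from datetime import datetime
from zoneinfo import ZoneInfo

nyc = ZoneInfo("America/New_York")
utc = ZoneInfo("UTC")

# Same 9am wall-clock time, different UTC offsets across the DST boundary.
winter = datetime(2024, 1, 15, 9, 0, tzinfo=nyc).astimezone(utc)  # 14:00 UTC (EST, UTC-5)
summer = datetime(2024, 7, 15, 9, 0, tzinfo=nyc).astimezone(utc)  # 13:00 UTC (EDT, UTC-4)

# Mixing naive and aware datetimes fails loudly.
try:
    datetime(2024, 1, 15, 9, 0) < winter
except TypeError:
    pass  # "can't compare offset-naive and offset-aware datetimes"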
Part C: Rank Slots by Preference
Problem: Part C
Not all free slots are equal. Rank them by “goodness” - prefer slots during business hours for all attendees, closer to midday.
def score_slot(slot: TimeSlot, people: List[Person]) -> float:
    """Lower score = better. Penalizes times far from noon for each person."""
    total_penalty = 0.0
    for person in people:
        local_start = localize(slot.start, person.timezone)
        hours_from_noon = abs(local_start.hour - 12)
        total_penalty += hours_from_noon
    return total_penalty

# Slot at 17:00 UTC:
#   Alice (NYC): 12:00 local -> penalty = 0
#   Bob (LA):    09:00 local -> penalty = 3
#   Total score: 3
# Slot at 19:00 UTC:
#   Alice (NYC): 14:00 local -> penalty = 2
#   Bob (LA):    11:00 local -> penalty = 1
#   Total score: 3 (tie, but more balanced)
Also consider: respecting each person’s working hours (9-5 in their timezone), filtering out slots outside those bounds.
Interview comments
Edge cases to probe:
What if there’s no overlapping free time at all?
How do you handle DST transitions?
What if busy slots overlap or are unsorted?
What about the final free slot after the last busy period?
Common mistakes:
Mixing naive and timezone-aware datetimes (causes TypeError or wrong results)
Forgetting the trailing free slot after the last busy period
Not clipping busy intervals to the search window
Using local time comparisons instead of normalizing to UTC first
Off-by-one: is a slot ending at 10:00 and one starting at 10:00 a conflict?
Code solutions
Solution 1 uses interval merging: collect all busy slots, normalize to UTC, merge overlapping intervals, then find gaps between busy periods. Solution 2 uses a sweep line algorithm that converts busy intervals to start/end events and sweeps through time tracking overlap count. Solution 3 adds working hours constraints and slot scoring based on proximity to noon in each attendee’s local timezone. The key difference is how they identify free time (gap-finding vs event sweeping) and whether they consider slot quality.
Core techniques: interval merging via sort-and-scan, sweep line algorithm, timezone normalization with ZoneInfo.
Solution 1: Interval Merging Approach
Interval merging approach. Collects all busy slots, normalizes to UTC, merges overlapping intervals, then finds gaps between busy periods that meet duration requirements.
"""Meeting Scheduler - Solution 1: Interval Merging ApproachFind common free slots across multiple attendees by merging busy intervalsand finding gaps that satisfy the required duration."""from dataclasses import dataclassfrom datetime import datetime, timedeltafrom typing import List, Tuplefrom zoneinfo import ZoneInfo@dataclassclass TimeSlot: start: datetime end: datetime def duration_minutes(self) -> int: return int((self.end - self.start).total_seconds() / 60)@dataclassclass Attendee: name: str timezone: str busy_slots: List[TimeSlot]def normalize_to_utc(slot: TimeSlot, tz: str) -> TimeSlot: """Convert a slot from local timezone to UTC.""" local_tz = ZoneInfo(tz) utc = ZoneInfo("UTC") start_utc = slot.start.replace(tzinfo=local_tz).astimezone(utc) end_utc = slot.end.replace(tzinfo=local_tz).astimezone(utc) return TimeSlot(start_utc.replace(tzinfo=None), end_utc.replace(tzinfo=None))def merge_intervals(slots: List[TimeSlot]) -> List[TimeSlot]: """Merge overlapping intervals.""" if not slots: return [] sorted_slots = sorted(slots, key=lambda s: s.start) merged = [sorted_slots[0]] for slot in sorted_slots[1:]: if slot.start <= merged[-1].end: merged[-1] = TimeSlot(merged[-1].start, max(merged[-1].end, slot.end)) else: merged.append(slot) return mergeddef find_free_slots( attendees: List[Attendee], search_start: datetime, search_end: datetime, duration_minutes: int,) -> List[TimeSlot]: """Find common free slots of at least duration_minutes.""" # Collect all busy slots normalized to UTC all_busy: List[TimeSlot] = [] for attendee in attendees: for slot in attendee.busy_slots: all_busy.append(normalize_to_utc(slot, attendee.timezone)) merged_busy = merge_intervals(all_busy) # Find gaps between busy periods free_slots: List[TimeSlot] = [] current = search_start for busy in merged_busy: if busy.start > current: gap = TimeSlot(current, min(busy.start, search_end)) if gap.duration_minutes() >= duration_minutes: free_slots.append(gap) current = max(current, busy.end) if current < search_end: gap = TimeSlot(current, search_end) if gap.duration_minutes() >= duration_minutes: free_slots.append(gap) return free_slots
Solution 2: Event Sweep Line Algorithm
Sweep line algorithm. Converts busy intervals to start/end events, sweeps through time tracking overlap count. Free time exists where overlap count is zero.
"""Meeting Scheduler - Solution 2: Event Sweep Line AlgorithmUses a sweep line approach: convert all busy intervals to start/end events,sweep through time tracking overlap count, find windows where count is 0."""from dataclasses import dataclassfrom datetime import datetimefrom typing import Listfrom zoneinfo import ZoneInfo@dataclassclass TimeSlot: start: datetime end: datetime def duration_minutes(self) -> int: return int((self.end - self.start).total_seconds() / 60)@dataclassclass Calendar: owner: str timezone: str busy_slots: List[TimeSlot]def to_utc(dt: datetime, tz_name: str) -> datetime: """Convert naive datetime in given timezone to naive UTC datetime.""" tz = ZoneInfo(tz_name) return dt.replace(tzinfo=tz).astimezone(ZoneInfo("UTC")).replace(tzinfo=None)def find_available_slots( calendars: List[Calendar], window_start: datetime, window_end: datetime, min_duration: int,) -> List[TimeSlot]: """Find all slots where everyone is free for at least min_duration minutes.""" # Collect events: (time, delta) where delta is +1 for busy start, -1 for busy end events: List[tuple[datetime, int]] = [] for calendar in calendars: for slot in calendar.busy_slots: utc_start = to_utc(slot.start, calendar.timezone) utc_end = to_utc(slot.end, calendar.timezone) # Clip to search window if utc_end <= window_start or utc_start >= window_end: continue clipped_start = max(utc_start, window_start) clipped_end = min(utc_end, window_end) events.append((clipped_start, 1)) # Start of busy period events.append((clipped_end, -1)) # End of busy period # Sort by time; at same time, process ends (-1) before starts (+1) events.sort(key=lambda e: (e[0], e[1])) available: List[TimeSlot] = [] busy_count = 0 free_start = window_start for time, delta in events: if delta == 1: # Busy period starting if busy_count == 0 and time > free_start: slot = TimeSlot(free_start, time) if slot.duration_minutes() >= min_duration: available.append(slot) busy_count += 1 else: # Busy period ending busy_count -= 1 if busy_count == 0: free_start = time # Handle trailing free time if busy_count == 0 and free_start < window_end: slot = TimeSlot(free_start, window_end) if slot.duration_minutes() >= min_duration: available.append(slot) return availabledef suggest_optimal_times( slots: List[TimeSlot], preferred_hour_utc: int = 14) -> List[TimeSlot]: """Rank slots by proximity to preferred hour (e.g., 2 PM UTC).""" def score(slot: TimeSlot) -> float: hour_diff = abs(slot.start.hour - preferred_hour_utc) return hour_diff + (1 / max(slot.duration_minutes(), 1)) return sorted(slots, key=score)
Solution 3: Working Hours Constraint + Scoring
Working hours constraint with scoring. Respects each attendee’s working hours, scores slots by convenience (proximity to noon in each person’s local time).
"""Meeting Scheduler - Solution 3: Working Hours Constraint + ScoringMore realistic approach: respects working hours per attendee,scores slots by convenience (avoiding early/late times across timezones)."""from dataclasses import dataclass, fieldfrom datetime import datetime, timedeltafrom typing import Listfrom zoneinfo import ZoneInfo@dataclassclass TimeSlot: start: datetime end: datetime def duration_minutes(self) -> int: return int((self.end - self.start).total_seconds() / 60)@dataclassclass WorkingHours: start_hour: int = 9 end_hour: int = 17@dataclassclass Person: name: str timezone: str busy: List[TimeSlot] = field(default_factory=list) working_hours: WorkingHours = field(default_factory=WorkingHours)def localize(dt: datetime, tz_name: str) -> datetime: """Convert naive UTC datetime to naive local datetime.""" utc = ZoneInfo("UTC") local = ZoneInfo(tz_name) return dt.replace(tzinfo=utc).astimezone(local).replace(tzinfo=None)def to_utc(dt: datetime, tz_name: str) -> datetime: """Convert naive local datetime to naive UTC datetime.""" local = ZoneInfo(tz_name) utc = ZoneInfo("UTC") return dt.replace(tzinfo=local).astimezone(utc).replace(tzinfo=None)def is_within_working_hours(dt_utc: datetime, person: Person) -> bool: """Check if UTC time falls within person's working hours.""" local_dt = localize(dt_utc, person.timezone) return person.working_hours.start_hour <= local_dt.hour < person.working_hours.end_hourdef is_busy(dt_utc: datetime, person: Person) -> bool: """Check if person has a conflicting meeting at given UTC time.""" for slot in person.busy: slot_start = to_utc(slot.start, person.timezone) slot_end = to_utc(slot.end, person.timezone) if slot_start <= dt_utc < slot_end: return True return Falsedef score_slot(slot: TimeSlot, people: List[Person]) -> float: """Score a slot: lower is better. Penalizes times far from midday for each person.""" total_penalty = 0.0 for person in people: local_start = localize(slot.start, person.timezone) # Ideal time is noon; penalize distance from it hours_from_noon = abs(local_start.hour + local_start.minute / 60 - 12) total_penalty += hours_from_noon return total_penaltydef find_meeting_slots( people: List[Person], search_date: datetime, duration_minutes: int, step_minutes: int = 15,) -> List[tuple[TimeSlot, float]]: """Find and rank available meeting slots for a given date.""" # Search window: 00:00 to 23:59 UTC on the given date window_start = search_date.replace(hour=0, minute=0, second=0, microsecond=0) window_end = window_start + timedelta(days=1) candidates: List[tuple[TimeSlot, float]] = [] current = window_start while current + timedelta(minutes=duration_minutes) <= window_end: slot_end = current + timedelta(minutes=duration_minutes) slot = TimeSlot(current, slot_end) # Check all time points within the slot all_available = True check_time = current while check_time < slot_end: for person in people: if not is_within_working_hours(check_time, person): all_available = False break if is_busy(check_time, person): all_available = False break if not all_available: break check_time += timedelta(minutes=step_minutes) if all_available: candidates.append((slot, score_slot(slot, people))) current += timedelta(minutes=step_minutes) # Sort by score (lower is better) candidates.sort(key=lambda x: x[1]) return candidates
Question 17 - LSM Tree
Difficulty: 10 / 10
Approximate lines of code: 90 LoC
Tags: storage, data-structures
Description
An LSM (Log-Structured Merge) Tree is the storage engine behind modern databases like LevelDB, RocksDB, and Cassandra. It optimizes for write-heavy workloads by buffering writes in memory and periodically flushing them to disk as immutable sorted files called SSTables (Sorted String Tables). Reads check the in-memory buffer first, then search through SSTables from newest to oldest.
The key data structures are: (1) a memtable (typically a balanced tree or hash map) that holds recent writes in memory, (2) a list of SSTables on disk, each containing sorted key-value pairs, and (3) optional bloom filters to skip SSTables that definitely don’t contain a key.
Note: The sortedcontainers library (SortedDict) is available. Using SortedDict for the memtable gives sorted iteration during flush without a separate sort step.
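As a quick sketch of why that helps (illustrative only, not the exact classes used in the solutions below), a SortedDict memtable iterates in key order, so a flush is just copying items out:
from sortedcontainers import SortedDict

memtable = SortedDict()  # hypothetical memtable for illustration
memtable["b"] = "2"
memtable["a"] = "1"
memtable["c"] = None     # tombstone

sstable_data = list(memtable.items())  # already sorted: [('a', '1'), ('b', '2'), ('c', None)]
memtable.clear()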
Part A: Basic Operations
Problem: Part A
Implement a key-value store with put(key, value) and get(key). When the memtable exceeds a size threshold, flush it to a new SSTable. The critical insight: reads must check the memtable first, then SSTables in newest-to-oldest order (the first match wins).
lsm = LSMTree(memtable_threshold=3)
lsm.put("a", "1")
lsm.put("b", "2")
# Internal state: memtable = {"a": "1", "b": "2"}, sstables = []

lsm.put("c", "3")  # Triggers flush
# Internal state: memtable = {}, sstables = [SSTable([("a","1"), ("b","2"), ("c","3")])]

lsm.put("a", "10")  # Update in new memtable
# Internal state: memtable = {"a": "10"}, sstables = [SSTable([...])]

lsm.get("a")  # Returns "10" (found in memtable, doesn't check SSTable)
lsm.get("b")  # Returns "2" (not in memtable, found in SSTable via binary search)
Part B: Delete Support
Problem: Part B
Add delete(key). The trick: you cannot simply remove the key from the memtable because older SSTables may still contain it. Instead, write a tombstone marker (e.g., None value) that shadows any older values.
lsm.put("x", "100")lsm.put("y", "200")lsm.put("z", "300") # Flush happens# sstables[0] = [("x","100"), ("y","200"), ("z","300")]lsm.delete("x")# memtable = {"x": None} <-- tombstonelsm.get("x") # Returns None (tombstone found in memtable)lsm.get("y") # Returns "200" (no tombstone, found in SSTable)
Part C: Compaction
Problem: Part C
With many SSTables, reads become slow (must check each one). Implement compact() that merges all SSTables together, keeping only the newest value for each key and removing tombstones for keys that have no older versions.
# Before compaction:
# sstables[0] = [("a", "new")]              # Newer
# sstables[1] = [("a", "old"), ("b", "2")]  # Older
lsm.compact()
# After: sstables = [SSTable([("a", "new"), ("b", "2")])]

# Tombstone cleanup:
# If sstables[0] = [("x", None)] and sstables[1] = [("x", "old")]
# After compact: key "x" is removed entirely (tombstone + old value both gone)
Interview comments
Edge cases to probe:
What happens if you get() a key that was deleted? (Should return None)
What if you delete a key that doesn’t exist? (Write tombstone anyway - it might exist in an SSTable you haven’t checked)
What order do you check SSTables? (Newest first - order matters!)
When can you safely remove tombstones? (Only during compaction, when you merge with all older SSTables)
Common mistakes:
Checking SSTables before memtable, or checking oldest SSTable first
Deleting directly from memtable without writing a tombstone
Removing tombstones during flush (too early - older SSTables still have the value)
Off-by-one errors in binary search within SSTables
Forgetting that bisect_left needs proper tuple comparison
Code solutions
Solution 1 uses an in-memory dict for the memtable and sorted lists for SSTables with binary search lookups. Solution 2 writes SSTables to JSON files on disk and adds bloom filters to skip unnecessary reads. Solution 3 implements tiered compaction with multiple levels, modeling how production systems like LevelDB organize data. The key differences are in storage medium (memory vs files) and compaction strategy (single merge vs tiered levels). Core techniques: binary search, hash maps, bloom filters, merge sort.
Solution 1: In-memory with sorted lists
Basic in-memory implementation using a dict for the memtable and sorted lists for SSTables. Uses binary search via bisect for efficient SSTable lookups. Simple and clean - good for demonstrating core concepts.
"""LSM Tree - Basic Implementation with in-memory segments."""from dataclasses import dataclass, fieldfrom typing import Optionalimport bisect@dataclassclass SSTable: """Immutable sorted segment stored on 'disk' (simulated in memory).""" data: list[tuple[str, Optional[str]]] # Sorted list of (key, value) pairs def get(self, key: str) -> tuple[bool, Optional[str]]: """Binary search for key. Returns (found, value).""" idx = bisect.bisect_left(self.data, (key,)) if idx < len(self.data) and self.data[idx][0] == key: return True, self.data[idx][1] return False, None@dataclassclass LSMTree: """Log-Structured Merge Tree with memtable and SSTables.""" memtable: dict[str, Optional[str]] = field(default_factory=dict) sstables: list[SSTable] = field(default_factory=list) # Newest first memtable_threshold: int = 4 def put(self, key: str, value: str) -> None: """Insert or update a key-value pair.""" self.memtable[key] = value if len(self.memtable) >= self.memtable_threshold: self._flush() def delete(self, key: str) -> None: """Delete a key by writing a tombstone.""" self.memtable[key] = None def get(self, key: str) -> Optional[str]: """Read a value. Checks memtable first, then SSTables newest to oldest.""" if key in self.memtable: return self.memtable[key] # Returns None for tombstones for sstable in self.sstables: found, value = sstable.get(key) if found: return value return None def _flush(self) -> None: """Flush memtable to a new SSTable.""" if not self.memtable: return sorted_data = sorted(self.memtable.items()) self.sstables.insert(0, SSTable(data=sorted_data)) self.memtable.clear() def compact(self) -> None: """Merge all SSTables into one, removing tombstones.""" self._flush() if len(self.sstables) <= 1: return merged: dict[str, Optional[str]] = {} for sstable in reversed(self.sstables): # Oldest first for key, value in sstable.data: merged[key] = value # Remove tombstones live_data = [(k, v) for k, v in sorted(merged.items()) if v is not None] self.sstables = [SSTable(data=live_data)] if live_data else []
Solution 2: File-based with bloom filters
File-based implementation that writes SSTables to JSON files. Adds bloom filters for each SSTable to skip disk reads when a key definitely isn’t present. More realistic for understanding actual LSM tree I/O patterns.
"""LSM Tree - File-based simulation with bloom filters."""from dataclasses import dataclass, fieldfrom typing import Optionalimport hashlibimport jsonfrom pathlib import Pathimport tempfileimport shutil@dataclassclass BloomFilter: """Simple bloom filter for probabilistic key existence checks.""" size: int = 64 bits: int = 0 def _hashes(self, key: str) -> list[int]: h = hashlib.md5(key.encode()).hexdigest() return [int(h[i:i+4], 16) % self.size for i in range(0, 12, 4)] def add(self, key: str) -> None: for h in self._hashes(key): self.bits |= (1 << h) def might_contain(self, key: str) -> bool: return all((self.bits & (1 << h)) for h in self._hashes(key))@dataclassclass SSTable: """SSTable stored as a JSON file with bloom filter.""" path: Path bloom: BloomFilter = field(default_factory=BloomFilter) def get(self, key: str) -> tuple[bool, Optional[str]]: if not self.bloom.might_contain(key): return False, None data = json.loads(self.path.read_text()) if key in data: return True, data[key] return False, None def items(self) -> list[tuple[str, Optional[str]]]: return list(json.loads(self.path.read_text()).items())@dataclassclass LSMTree: """LSM Tree with file-based SSTables and bloom filters.""" base_dir: Path memtable: dict[str, Optional[str]] = field(default_factory=dict) sstables: list[SSTable] = field(default_factory=list) memtable_threshold: int = 4 _counter: int = 0 def put(self, key: str, value: str) -> None: self.memtable[key] = value if len(self.memtable) >= self.memtable_threshold: self._flush() def delete(self, key: str) -> None: self.memtable[key] = None def get(self, key: str) -> Optional[str]: if key in self.memtable: return self.memtable[key] for sstable in self.sstables: found, value = sstable.get(key) if found: return value return None def _flush(self) -> None: if not self.memtable: return bloom = BloomFilter() for key in self.memtable: bloom.add(key) path = self.base_dir / f"sstable_{self._counter}.json" path.write_text(json.dumps(dict(sorted(self.memtable.items())))) self._counter += 1 self.sstables.insert(0, SSTable(path=path, bloom=bloom)) self.memtable.clear() def compact(self) -> None: self._flush() if len(self.sstables) <= 1: return merged: dict[str, Optional[str]] = {} for sstable in reversed(self.sstables): for k, v in sstable.items(): merged[k] = v for sstable in self.sstables: sstable.path.unlink() live = {k: v for k, v in merged.items() if v is not None} self.sstables.clear() if live: self.memtable = live self._flush()
Solution 3: Tiered compaction with multiple levels
Tiered compaction with multiple levels. SSTables are organized into levels; when a level fills up, its tables merge and push to the next level. This models how production systems like LevelDB organize data for better read amplification.
"""LSM Tree - Tiered compaction with multiple levels."""from dataclasses import dataclass, fieldfrom typing import Optionalimport bisect@dataclassclass SSTable: """Immutable sorted segment with size tracking.""" data: list[tuple[str, Optional[str]]] def get(self, key: str) -> tuple[bool, Optional[str]]: idx = bisect.bisect_left(self.data, (key,)) if idx < len(self.data) and self.data[idx][0] == key: return True, self.data[idx][1] return False, None def __len__(self) -> int: return len(self.data)@dataclassclass Level: """A level in the LSM tree containing multiple SSTables.""" sstables: list[SSTable] = field(default_factory=list) max_tables: int = 4 def is_full(self) -> bool: return len(self.sstables) >= self.max_tables def add(self, sstable: SSTable) -> None: self.sstables.insert(0, sstable) def merge_all(self) -> SSTable: merged: dict[str, Optional[str]] = {} for sstable in reversed(self.sstables): for k, v in sstable.data: merged[k] = v self.sstables.clear() return SSTable(data=sorted(merged.items()))@dataclassclass LSMTree: """LSM Tree with tiered compaction across multiple levels.""" memtable: dict[str, Optional[str]] = field(default_factory=dict) levels: list[Level] = field(default_factory=list) memtable_threshold: int = 4 num_levels: int = 3 def __post_init__(self) -> None: self.levels = [Level(max_tables=4) for _ in range(self.num_levels)] def put(self, key: str, value: str) -> None: self.memtable[key] = value if len(self.memtable) >= self.memtable_threshold: self._flush() def delete(self, key: str) -> None: self.memtable[key] = None def get(self, key: str) -> Optional[str]: if key in self.memtable: return self.memtable[key] for level in self.levels: for sstable in level.sstables: found, value = sstable.get(key) if found: return value return None def _flush(self) -> None: if not self.memtable: return sstable = SSTable(data=sorted(self.memtable.items())) self.memtable.clear() self._add_to_level(0, sstable) def _add_to_level(self, level_idx: int, sstable: SSTable) -> None: if level_idx >= len(self.levels): self.levels.append(Level(max_tables=4)) level = self.levels[level_idx] level.add(sstable) if level.is_full(): merged = level.merge_all() self._add_to_level(level_idx + 1, merged) def compact(self) -> None: """Force full compaction, removing tombstones from final level.""" self._flush() for i in range(len(self.levels) - 1): if self.levels[i].sstables: merged = self.levels[i].merge_all() self._add_to_level(i + 1, merged) # Remove tombstones from last level if self.levels and self.levels[-1].sstables: final = self.levels[-1].merge_all() live = [(k, v) for k, v in final.data if v is not None] if live: self.levels[-1].sstables = [SSTable(data=live)]
Question 18 - Undo/Redo
Difficulty: 4 / 10
Approximate lines of code: 100 LoC
Tags: data-structures
Description
An undo/redo system allows users to reverse and replay operations in applications like text editors, graphics software, or games. The key insight is that operations must be reversible - you need to store enough information to reconstruct the previous state. There are two main approaches: the command pattern (store operations with their inverses) or the memento pattern (store full state snapshots). Command pattern uses less memory but requires each operation to be invertible; memento is simpler but stores redundant data.
Internal state typically includes two stacks: an undo stack (executed commands waiting to be undone) and a redo stack (undone commands waiting to be replayed). When a new command executes, it clears the redo stack - you cannot redo after making a new change.
Part A: Execute, Undo, Redo
Problem: Part A
Implement a basic undo/redo manager with three operations: execute(command), undo(), and redo(). Each command must support both execution and reversal.
manager = UndoRedoManager()
doc = Document(content="")

manager.execute(InsertCommand(doc, 0, "Hello"))
# doc.content = "Hello"
# undo_stack = [InsertCommand("Hello" at 0)]
# redo_stack = []

manager.execute(InsertCommand(doc, 5, " World"))
# doc.content = "Hello World"
# undo_stack = [InsertCommand("Hello" at 0), InsertCommand(" World" at 5)]

manager.undo()
# doc.content = "Hello"
# undo_stack = [InsertCommand("Hello" at 0)]
# redo_stack = [InsertCommand(" World" at 5)]

manager.redo()
# doc.content = "Hello World"
# undo_stack = [InsertCommand("Hello" at 0), InsertCommand(" World" at 5)]
# redo_stack = []

manager.execute(InsertCommand(doc, 5, "!"))
# doc.content = "Hello! World"
# undo_stack = [..., InsertCommand("!" at 5)]
# redo_stack = []  # CLEARED - new command invalidates redo history
Part B: Command Grouping
Problem: Part B
Add support for grouping multiple commands into a single undoable operation. When the user undoes a group, all commands in the group are reversed together in reverse order.
group = CommandGroup([
    InsertCommand(doc, 0, "A"),
    InsertCommand(doc, 1, "B"),
    InsertCommand(doc, 2, "C"),
])
manager.execute(group)
# doc.content = "ABC"
# undo_stack = [CommandGroup([A, B, C])]

manager.undo()
# Internally: undo C, then B, then A (reverse order!)
# doc.content = ""
Part C: Memory Limits
Problem: Part C
Add a maximum history size. When the undo stack exceeds the limit, discard the oldest commands. Consider: should you track memory by command count or actual bytes? How do you handle groups that push you over the limit?
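One possible sketch (names are illustrative, and it assumes commands expose execute()/undo() as in Part A): bound history by command count with a deque(maxlen=...), which silently drops the oldest entry. Tracking bytes instead would need each command to report an approximate size.
from collections import deque

class BoundedUndoManager:
    def __init__(self, max_history: int = 100) -> None:
        self.undo_stack: deque = deque(maxlen=max_history)  # oldest command dropped automatically
        self.redo_stack: list = []

    def execute(self, command) -> None:
        command.execute()
        self.undo_stack.append(command)
        self.redo_stack.clear()  # new command invalidates redo history

    def undo(self) -> bool:
        if not self.undo_stack:
            return False
        command = self.undo_stack.pop()
        command.undo()
        self.redo_stack.append(command)
        return True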
Interview comments
Edge cases to probe:
What happens when you undo with an empty undo stack?
What happens when you redo after executing a new command?
How do you undo a delete operation? (Must store the deleted content)
What if undo itself fails partway through a group?
Common mistakes:
Forgetting to clear the redo stack when executing a new command
Undoing grouped commands in forward order instead of reverse order
Not storing enough state to reverse delete operations
Shallow copying mutable state in memento pattern (changes leak through)
Code solutions
Solution 1 uses the classic command pattern with an abstract base class where each command implements execute() and undo() methods. Solution 2 takes a functional approach using closures, where commands are functions that return their undo function when executed. Solution 3 uses the memento pattern, storing complete state snapshots instead of reversible commands. The key difference is the trade-off between memory usage and implementation complexity: command pattern is memory-efficient but requires invertible operations, memento is simpler but stores redundant state. Core techniques: command pattern, closure-based undo, state snapshots with deepcopy, stack-based history management.
Solution 1: Classic Command Pattern
Classic command pattern using an abstract base class. Each command (Insert, Delete) implements execute() and undo() methods. Commands store the information needed to reverse themselves. Groups are commands containing other commands.
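The code block for this solution didn't survive into this page, so here is a hedged sketch of the pattern as described (an abstract Command base class, an InsertCommand that stores what it needs to reverse itself, and a manager with two stacks). Treat it as an outline rather than the original solution:
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

@dataclass
class Document:
    content: str = ""

class Command(ABC):
    @abstractmethod
    def execute(self) -> None: ...

    @abstractmethod
    def undo(self) -> None: ...

@dataclass
class InsertCommand(Command):
    doc: Document
    pos: int
    text: str

    def execute(self) -> None:
        self.doc.content = self.doc.content[:self.pos] + self.text + self.doc.content[self.pos:]

    def undo(self) -> None:
        # Remove exactly the text we inserted.
        self.doc.content = self.doc.content[:self.pos] + self.doc.content[self.pos + len(self.text):]

@dataclass
class UndoRedoManager:
    undo_stack: list[Command] = field(default_factory=list)
    redo_stack: list[Command] = field(default_factory=list)

    def execute(self, command: Command) -> None:
        command.execute()
        self.undo_stack.append(command)
        self.redo_stack.clear()  # new command invalidates redo history

    def undo(self) -> bool:
        if not self.undo_stack:
            return False
        command = self.undo_stack.pop()
        command.undo()
        self.redo_stack.append(command)
        return True

    def redo(self) -> bool:
        if not self.redo_stack:
            return False
        command = self.redo_stack.pop()
        command.execute()
        self.undo_stack.append(command)
        return True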
Solution 2: Functional Approach with Closures
Functional approach using closures. Commands are functions that return their undo function when executed. Cleaner for simple cases with no class hierarchy needed. Must store both the command and its undo function to support redo.
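Again, the original block isn't reproduced here; a minimal sketch of the closure idea (a command is a zero-argument function that performs the edit and returns its own undo function, and the manager keeps (command, undo_fn) pairs so redo can simply re-run the command) might look like this:
from typing import Callable

UndoFn = Callable[[], None]
CommandFn = Callable[[], UndoFn]

def make_insert(buffer: list, pos: int, text: str) -> CommandFn:
    """Build a command that inserts `text` at `pos` in a list-based buffer."""
    def command() -> UndoFn:
        buffer.insert(pos, text)
        def undo() -> None:
            del buffer[pos]
        return undo
    return command

class ClosureUndoManager:
    def __init__(self) -> None:
        self._undo: list = []  # (command, undo_fn) pairs
        self._redo: list = []  # commands waiting to be re-run

    def execute(self, command: CommandFn) -> None:
        undo_fn = command()
        self._undo.append((command, undo_fn))
        self._redo.clear()

    def undo(self) -> bool:
        if not self._undo:
            return False
        command, undo_fn = self._undo.pop()
        undo_fn()
        self._redo.append(command)
        return True

    def redo(self) -> bool:
        if not self._redo:
            return False
        command = self._redo.pop()
        self._undo.append((command, command()))
        return True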
Solution 3: Memento Pattern (State Snapshots)
Memento pattern using state snapshots. Instead of reversible commands, store complete copies of state after each operation. Trades memory for simplicity - no need to implement undo logic per operation. Uses deepcopy to prevent mutation leakage.
"""Undo/Redo System - Solution 3: Memento Pattern (State Snapshots)Instead of reversible commands, store full state snapshots.Simpler logic but higher memory usage. Good for small state objects."""from dataclasses import dataclass, fieldfrom typing import Generic, TypeVarfrom copy import deepcopyT = TypeVar('T')@dataclassclass GameState: """Example state: a simple game with position and score.""" x: int = 0 y: int = 0 score: int = 0 inventory: list[str] = field(default_factory=list)@dataclassclass UndoRedoManager(Generic[T]): """ Manages undo/redo using state snapshots (memento pattern). Stores complete copies of state, trading memory for simplicity. """ history: list[T] = field(default_factory=list) current_index: int = -1 def save_state(self, state: T) -> None: """Save a snapshot after each mutation.""" # Discard any redo history when new state is saved self.history = self.history[:self.current_index + 1] self.history.append(deepcopy(state)) self.current_index += 1 def undo(self) -> T | None: """Return the previous state, or None if at beginning.""" if self.current_index <= 0: return None self.current_index -= 1 return deepcopy(self.history[self.current_index]) def redo(self) -> T | None: """Return the next state, or None if at end.""" if self.current_index >= len(self.history) - 1: return None self.current_index += 1 return deepcopy(self.history[self.current_index]) def can_undo(self) -> bool: return self.current_index > 0 def can_redo(self) -> bool: return self.current_index < len(self.history) - 1@dataclassclass GameEngine: """Game engine that uses the undo/redo manager.""" state: GameState = field(default_factory=GameState) history: UndoRedoManager[GameState] = field(default_factory=UndoRedoManager) def __post_init__(self) -> None: self.history.save_state(self.state) def move(self, dx: int, dy: int) -> None: self.state.x += dx self.state.y += dy self.history.save_state(self.state) def add_score(self, points: int) -> None: self.state.score += points self.history.save_state(self.state) def collect_item(self, item: str) -> None: self.state.inventory.append(item) self.history.save_state(self.state) def batch_actions(self, actions: list[tuple[str, ...]]) -> None: """Execute multiple actions as a single undoable operation.""" for action in actions: name, *args = action if name == "move": self.state.x += int(args[0]) self.state.y += int(args[1]) elif name == "score": self.state.score += int(args[0]) elif name == "collect": self.state.inventory.append(str(args[0])) self.history.save_state(self.state) def undo(self) -> bool: prev_state = self.history.undo() if prev_state is None: return False self.state = prev_state return True def redo(self) -> bool: next_state = self.history.redo() if next_state is None: return False self.state = next_state return True
Question 19 - Interval Tree
Difficulty: 9 / 10
Approximate lines of code: 100 LoC
Tags: data-structures
Description
An interval tree efficiently stores intervals (like meeting times, IP ranges, or genomic regions) and answers queries about which intervals contain a point or overlap with a query range. The naive approach of checking every interval is O(n) per query. An interval tree achieves O(log n + k) where k is the number of results by augmenting a BST with subtree metadata that enables pruning entire branches.
The key insight is storing max_end at each node: the maximum endpoint among all intervals in that subtree. When querying for point P, if P > node.max_end, you can skip the entire subtree since no interval there can contain P. The tree is ordered by interval start points (like a normal BST), with the max_end augmentation propagated up during insertions and deletions.
Note: The sortedcontainers library (SortedList) is available. If using a sorted list approach instead of a tree, SortedList gives O(log n) insertions instead of O(n) with bisect.insort.
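For example (illustrative snippet, separate from the solutions below), keeping intervals ordered by start point with SortedList:
from sortedcontainers import SortedList

intervals = SortedList(key=lambda iv: iv[0])  # (start, end) tuples ordered by start
intervals.add((15, 20))
intervals.add((5, 10))
intervals.add((10, 30))
# list(intervals) == [(5, 10), (10, 30), (15, 20)]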
Part A: Basic Structure
Problem: Part A
Build an interval tree supporting insert(start, end) and query_point(point). Each node stores an interval and the maximum endpoint in its subtree. Insert maintains BST ordering by start point and updates max_end on the path to the root.
tree = IntervalTree()
tree.insert(15, 20)
tree.insert(10, 30)
tree.insert(5, 10)
tree.insert(17, 19)
# Tree structure (ordered by start):
#               [10,30] max_end=30
#              /                   \
#   [5,10] max_end=10       [15,20] max_end=20
#                                    \
#                             [17,19] max_end=19

result = tree.query_point(18)
# Returns: [[10,30], [15,20], [17,19]]
# Query traverses: checks [10,30] (contains 18), recurses left (max_end=10 < 18, prune),
# recurses right to [15,20] (contains 18), then [17,19] (contains 18)

result = tree.query_point(7)
# Returns: [[5,10]]
# [10,30] doesn't contain 7, but max_end=30 >= 7 so we check children
# Left subtree [5,10] contains 7
Part B: Overlap Queries
Problem: Part B
Add query_range(start, end) that finds all intervals overlapping with the query interval. Two intervals [a,b] and [c,d] overlap if a <= d AND c <= b. Use the same pruning strategy with max_end.
tree = IntervalTree()
tree.insert(15, 20)
tree.insert(10, 30)
tree.insert(5, 10)
tree.insert(25, 30)

result = tree.query_range(12, 16)
# Returns: [[10,30], [15,20]]
# [10,30] overlaps [12,16]: 10 <= 16 AND 12 <= 30
# [15,20] overlaps [12,16]: 15 <= 16 AND 12 <= 20
# [5,10] doesn't overlap: 5 <= 16 but 12 > 10
# [25,30] doesn't overlap: 25 > 16

# Edge case: touching intervals (share an endpoint)
result = tree.query_range(10, 10)
# Returns: [[5,10], [10,30]] - both touch point 10
Part C: Deletion
Problem: Part C
Implement delete(start, end) that removes an interval and maintains the max_end invariant. After removing a node, recompute max_end for all ancestors by taking the max of the node’s interval end and the max_end of its children.
tree = IntervalTree()
tree.insert(15, 20)
tree.insert(10, 30)
tree.insert(17, 19)
# Before delete:
# [10,30] max_end=30
#        \
#    [15,20] max_end=20
#           \
#      [17,19] max_end=19

tree.delete(17, 19)
# After delete:
# [10,30] max_end=30
#        \
#    [15,20] max_end=20  <-- max_end recomputed

# Deleting a node with two children uses standard BST deletion
# (replace with in-order successor), then recompute max_end up the path
Interview comments
Edge cases to probe:
Do [1,5] and [5,10] overlap? (Yes, they touch at 5 - use <= not <)
What if you insert single-point intervals like [7,7]?
How do you handle deleting a node with two children? (In-order successor, then recompute max_end)
What’s the worst-case query time? (O(n) if all intervals overlap the query)
Common mistakes:
Wrong overlap check: using < instead of <= (missing touching intervals)
Forgetting to update max_end on insert path
Not recomputing max_end correctly after delete (must check both children)
Pruning incorrectly: should check query.start > node.max_end not query.end
Off-by-one in deciding whether to recurse into children
Code solutions
Solution 1 uses an augmented BST where each node stores max_end (maximum endpoint in subtree) for efficient pruning during queries. Solution 2 is a simpler sorted list approach with binary search for insertion but O(n) queries. Solution 3 implements a centered interval tree that partitions intervals by a center point. They differ in the tradeoff between implementation complexity and query efficiency.
Solution 1: Augmented BST approach
Augmented BST approach where each node stores the interval and max_end (maximum endpoint in subtree). Insert/delete maintain the invariant. Query uses max_end to prune branches.
"""Interval Tree - Solution 1: Augmented BST approachUses an augmented binary search tree where each node stores the max endpoint in its subtree."""from dataclasses import dataclassfrom typing import Optional@dataclassclass Interval: start: int end: int def overlaps(self, other: "Interval") -> bool: return self.start <= other.end and other.start <= self.end def contains_point(self, point: int) -> bool: return self.start <= point <= self.end@dataclassclass Node: interval: Interval max_end: int left: Optional["Node"] = None right: Optional["Node"] = Noneclass IntervalTree: def __init__(self) -> None: self.root: Optional[Node] = None def insert(self, start: int, end: int) -> None: interval = Interval(start, end) self.root = self._insert(self.root, interval) def _insert(self, node: Optional[Node], interval: Interval) -> Node: if node is None: return Node(interval=interval, max_end=interval.end) if interval.start < node.interval.start: node.left = self._insert(node.left, interval) else: node.right = self._insert(node.right, interval) node.max_end = max(node.max_end, interval.end) return node def query_point(self, point: int) -> list[Interval]: results: list[Interval] = [] self._query_point(self.root, point, results) return results def _query_point(self, node: Optional[Node], point: int, results: list[Interval]) -> None: if node is None or point > node.max_end: return if node.interval.contains_point(point): results.append(node.interval) self._query_point(node.left, point, results) if point >= node.interval.start: self._query_point(node.right, point, results) def query_range(self, start: int, end: int) -> list[Interval]: query_interval = Interval(start, end) results: list[Interval] = [] self._query_range(self.root, query_interval, results) return results def _query_range(self, node: Optional[Node], query: Interval, results: list[Interval]) -> None: if node is None or query.start > node.max_end: return if node.interval.overlaps(query): results.append(node.interval) self._query_range(node.left, query, results) if query.end >= node.interval.start: self._query_range(node.right, query, results) def delete(self, start: int, end: int) -> bool: interval = Interval(start, end) self.root, deleted = self._delete(self.root, interval) return deleted def _delete(self, node: Optional[Node], interval: Interval) -> tuple[Optional[Node], bool]: if node is None: return None, False if node.interval.start == interval.start and node.interval.end == interval.end: if node.left is None: return node.right, True if node.right is None: return node.left, True successor = self._min_node(node.right) node.interval = successor.interval node.right, _ = self._delete(node.right, successor.interval) elif interval.start < node.interval.start: node.left, deleted = self._delete(node.left, interval) if not deleted: return node, False else: node.right, deleted = self._delete(node.right, interval) if not deleted: return node, False node.max_end = self._compute_max(node) return node, True def _min_node(self, node: Node) -> Node: while node.left: node = node.left return node def _compute_max(self, node: Node) -> int: max_end = node.interval.end if node.left: max_end = max(max_end, node.left.max_end) if node.right: max_end = max(max_end, node.right.max_end) return max_end
Solution 2: Sorted list with binary search
Sorted list approach using binary search for simpler implementation. O(n) queries but O(1) to understand and debug. Good baseline for smaller datasets.
"""Interval Tree - Solution 2: Sorted list with binary searchSimpler O(n) approach using a sorted list. Good for smaller datasets or whensimplicity matters more than performance."""from dataclasses import dataclassfrom bisect import insort_left, bisect_left@dataclass(frozen=True)class Interval: start: int end: int def overlaps(self, other: "Interval") -> bool: return self.start <= other.end and other.start <= self.end def contains_point(self, point: int) -> bool: return self.start <= point <= self.end def __lt__(self, other: "Interval") -> bool: return (self.start, self.end) < (other.start, other.end)class IntervalTree: """Simple interval container using sorted list.""" def __init__(self) -> None: self._intervals: list[Interval] = [] def insert(self, start: int, end: int) -> None: if start > end: raise ValueError(f"Invalid interval: start {start} > end {end}") interval = Interval(start, end) insort_left(self._intervals, interval) def query_point(self, point: int) -> list[Interval]: results: list[Interval] = [] for interval in self._intervals: if interval.start > point: break if interval.contains_point(point): results.append(interval) return results def query_range(self, start: int, end: int) -> list[Interval]: query = Interval(start, end) results: list[Interval] = [] for interval in self._intervals: if interval.start > end: break if interval.overlaps(query): results.append(interval) return results def delete(self, start: int, end: int) -> bool: target = Interval(start, end) idx = bisect_left(self._intervals, target) if idx < len(self._intervals) and self._intervals[idx] == target: self._intervals.pop(idx) return True return False def __len__(self) -> int: return len(self._intervals) def __iter__(self): return iter(self._intervals)
Solution 3: Centered Interval Tree
Centered interval tree approach. Picks a center point and stores intervals spanning it at the node, with intervals entirely left/right recursively placed in children. Maintains two sorted lists (by start and by end descending) for efficient point queries.
"""Interval Tree - Solution 3: Centered Interval TreeDivides intervals by a center point. Intervals spanning the center are stored at the node,others recursively placed in left/right subtrees."""from dataclasses import dataclass, fieldfrom typing import Optional@dataclass(frozen=True)class Interval: start: int end: int def overlaps(self, other: "Interval") -> bool: return self.start <= other.end and other.start <= self.end def contains_point(self, point: int) -> bool: return self.start <= point <= self.end@dataclassclass CenteredNode: center: int by_start: list[Interval] = field(default_factory=list) # sorted by start by_end: list[Interval] = field(default_factory=list) # sorted by end desc left: Optional["CenteredNode"] = None right: Optional["CenteredNode"] = Noneclass IntervalTree: def __init__(self) -> None: self._intervals: set[Interval] = set() self._root: Optional[CenteredNode] = None def insert(self, start: int, end: int) -> None: interval = Interval(start, end) self._intervals.add(interval) self._rebuild() def delete(self, start: int, end: int) -> bool: interval = Interval(start, end) if interval not in self._intervals: return False self._intervals.remove(interval) self._rebuild() return True def _rebuild(self) -> None: self._root = self._build(list(self._intervals)) def _build(self, intervals: list[Interval]) -> Optional[CenteredNode]: if not intervals: return None points = sorted(set(p for i in intervals for p in (i.start, i.end))) center = points[len(points) // 2] node = CenteredNode(center=center) left_intervals: list[Interval] = [] right_intervals: list[Interval] = [] for interval in intervals: if interval.end < center: left_intervals.append(interval) elif interval.start > center: right_intervals.append(interval) else: node.by_start.append(interval) node.by_end.append(interval) node.by_start.sort(key=lambda i: i.start) node.by_end.sort(key=lambda i: i.end, reverse=True) node.left = self._build(left_intervals) node.right = self._build(right_intervals) return node def query_point(self, point: int) -> list[Interval]: results: list[Interval] = [] self._query_point(self._root, point, results) return results def _query_point(self, node: Optional[CenteredNode], point: int, results: list[Interval]) -> None: if node is None: return if point < node.center: for interval in node.by_start: if interval.start > point: break if interval.contains_point(point): results.append(interval) self._query_point(node.left, point, results) elif point > node.center: for interval in node.by_end: if interval.end < point: break if interval.contains_point(point): results.append(interval) self._query_point(node.right, point, results) else: results.extend(node.by_start) def query_range(self, start: int, end: int) -> list[Interval]: query = Interval(start, end) return [i for i in self._intervals if i.overlaps(query)]
Question 20 - Sparse Matrix
Difficulty: 3 / 10
Approximate lines of code: 70 LoC
Tags: data-structures
Description
A sparse matrix is a matrix where most elements are zero. Instead of storing all n*m elements, we only store non-zero values with their positions. Common representations include DOK (dictionary of keys) using {(row, col): value}, CSR (compressed sparse row) using arrays for values, column indices, and row pointers, and row-based dictionaries using {row: {col: value}}. The choice depends on access patterns: DOK is simple and good for random access, CSR is efficient for row operations and matrix multiplication.
The key insight for multiplication is that you only need to iterate over non-zero elements: for each non-zero A[i,k], multiply with each non-zero B[k,j] and accumulate into C[i,j]. This gives O(nnz_A * avg_row_nnz_B) instead of O(n^3).
Part A: Basic Get/Set Operations
Problem: Part A
Implement get(row, col) and set(row, col, value). Setting a value to 0 should remove it from storage (not store zeros). Accessing an unset position returns 0.
m = SparseMatrix(rows=3, cols=3)
m.set(0, 0, 1.0)
m.set(1, 1, 2.0)
m.set(2, 2, 3.0)
# Internal state (DOK): {(0,0): 1.0, (1,1): 2.0, (2,2): 3.0}

m.get(0, 0)  # 1.0
m.get(0, 1)  # 0.0 (not stored, returns default)
m.get(1, 1)  # 2.0

m.set(1, 1, 0.0)
# Internal state: {(0,0): 1.0, (2,2): 3.0}
# Zero was removed, not stored
Part B: Addition and Transpose
Problem: Part B
Implement matrix addition (A + B) and transpose (swap rows and columns). Addition only needs to iterate over non-zero positions from both matrices.
a = SparseMatrix(3, 3)
a.set(0, 0, 1.0)
a.set(1, 1, 2.0)

b = SparseMatrix(3, 3)
b.set(0, 0, 3.0)  # Overlap with a
b.set(1, 2, 5.0)  # New position

c = a.add(b)
c.get(0, 0)  # 4.0 (1.0 + 3.0)
c.get(1, 1)  # 2.0 (only in a)
c.get(1, 2)  # 5.0 (only in b)

# Transpose
m = SparseMatrix(2, 3)
m.set(0, 2, 7.0)
t = m.transpose()  # Now 3x2
t.get(2, 0)  # 7.0 (was at row=0, col=2)
Part C: Matrix Multiplication
Problem: Part C
Implement matrix multiplication C = A * B. The efficient approach iterates over non-zero elements of A, and for each A[i,k], multiplies with the non-zero elements in row k of B.
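A minimal sketch of that loop with plain DOK-style dicts (illustrative values, standalone from the solutions below): group B's entries by row once, then for each non-zero A[i,k] only touch the non-zeros in row k of B.
# a and b map (row, col) -> value
a = {(0, 0): 1.0, (0, 2): 2.0, (1, 1): 3.0}
b = {(0, 1): 4.0, (2, 0): 5.0, (1, 2): 6.0}

# Group b's non-zeros by row so row k is a direct lookup.
b_rows: dict = {}
for (k, j), v in b.items():
    b_rows.setdefault(k, []).append((j, v))

c: dict = {}
for (i, k), val_a in a.items():
    for j, val_b in b_rows.get(k, []):
        c[(i, j)] = c.get((i, j), 0.0) + val_a * val_b
# c == {(0, 0): 10.0, (0, 1): 4.0, (1, 2): 18.0}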
Interview comments
Edge cases to probe:
What’s the complexity of your multiply vs dense O(n^3)?
Common mistakes:
Storing zeros (defeats the purpose of sparse storage)
O(n^3) multiplication iterating over all indices
Not checking dimension compatibility
Accumulating into result without initializing (missing values)
Code solutions
Solution 1 uses Dictionary of Keys (DOK) representation with {(row, col): value}, which is simple and good for random access. Solution 2 uses Compressed Sparse Row (CSR) format with three arrays (values, column indices, row pointers), which is more memory-efficient and cache-friendly for row operations. Solution 3 uses row-based dictionaries {row: {col: value}}, balancing DOK simplicity with efficient row access patterns. The key difference is the storage format and its access pattern tradeoffs: tuple keys vs compressed arrays vs nested dictionaries. Core techniques: sparse representations (DOK, CSR), efficient matrix multiplication via non-zero iteration.
Solution 1: Dictionary of Keys (DOK) representation
Dictionary of Keys (DOK) representation using {(row, col): value}. Simple and intuitive, good for random access and incremental construction.
"""Sparse Matrix - Solution 1: Dictionary of Keys (DOK) representation.Simple and intuitive approach using a dictionary with (row, col) tuples as keys."""from dataclasses import dataclass, fieldfrom typing import Dict, Tuple@dataclassclass SparseMatrix: rows: int cols: int data: Dict[Tuple[int, int], float] = field(default_factory=dict) def get(self, row: int, col: int) -> float: if not (0 <= row < self.rows and 0 <= col < self.cols): raise IndexError(f"Index ({row}, {col}) out of bounds") return self.data.get((row, col), 0.0) def set(self, row: int, col: int, value: float) -> None: if not (0 <= row < self.rows and 0 <= col < self.cols): raise IndexError(f"Index ({row}, {col}) out of bounds") if value == 0.0: self.data.pop((row, col), None) else: self.data[(row, col)] = value def add(self, other: "SparseMatrix") -> "SparseMatrix": if self.rows != other.rows or self.cols != other.cols: raise ValueError("Matrix dimensions must match for addition") result = SparseMatrix(self.rows, self.cols) all_keys = set(self.data.keys()) | set(other.data.keys()) for key in all_keys: val = self.data.get(key, 0.0) + other.data.get(key, 0.0) if val != 0.0: result.data[key] = val return result def multiply(self, other: "SparseMatrix") -> "SparseMatrix": if self.cols != other.rows: raise ValueError(f"Cannot multiply {self.rows}x{self.cols} by {other.rows}x{other.cols}") result = SparseMatrix(self.rows, other.cols) for (i, k), val_a in self.data.items(): for j in range(other.cols): if (k, j) in other.data: new_val = result.get(i, j) + val_a * other.data[(k, j)] result.set(i, j, new_val) return result def transpose(self) -> "SparseMatrix": result = SparseMatrix(self.cols, self.rows) for (row, col), val in self.data.items(): result.data[(col, row)] = val return result def __repr__(self) -> str: return f"SparseMatrix({self.rows}x{self.cols}, nnz={len(self.data)})"
Solution 2: Compressed Sparse Row (CSR) format
Compressed Sparse Row (CSR) format using three arrays: values, column indices, and row pointers. More memory-efficient and cache-friendly for row-based operations.
"""Sparse Matrix - Solution 2: Compressed Sparse Row (CSR) format.More memory-efficient for row-based operations and matrix multiplication."""from dataclasses import dataclass, fieldfrom typing import List@dataclassclass SparseMatrixCSR: rows: int cols: int values: List[float] = field(default_factory=list) col_indices: List[int] = field(default_factory=list) row_ptrs: List[int] = field(default_factory=lambda: [0]) @classmethod def from_dict(cls, rows: int, cols: int, data: dict) -> "SparseMatrixCSR": matrix = cls(rows, cols) sorted_entries = sorted(data.items(), key=lambda x: (x[0][0], x[0][1])) current_row = 0 for (r, c), v in sorted_entries: while current_row < r: matrix.row_ptrs.append(len(matrix.values)) current_row += 1 matrix.values.append(v) matrix.col_indices.append(c) while len(matrix.row_ptrs) <= rows: matrix.row_ptrs.append(len(matrix.values)) return matrix def get(self, row: int, col: int) -> float: if not (0 <= row < self.rows and 0 <= col < self.cols): raise IndexError(f"Index ({row}, {col}) out of bounds") start, end = self.row_ptrs[row], self.row_ptrs[row + 1] for i in range(start, end): if self.col_indices[i] == col: return self.values[i] return 0.0 def to_dict(self) -> dict: result = {} for row in range(self.rows): start, end = self.row_ptrs[row], self.row_ptrs[row + 1] for i in range(start, end): result[(row, self.col_indices[i])] = self.values[i] return result def add(self, other: "SparseMatrixCSR") -> "SparseMatrixCSR": if self.rows != other.rows or self.cols != other.cols: raise ValueError("Matrix dimensions must match") combined = self.to_dict() for (r, c), v in other.to_dict().items(): combined[(r, c)] = combined.get((r, c), 0.0) + v combined = {k: v for k, v in combined.items() if v != 0.0} return SparseMatrixCSR.from_dict(self.rows, self.cols, combined) def multiply(self, other: "SparseMatrixCSR") -> "SparseMatrixCSR": if self.cols != other.rows: raise ValueError("Incompatible dimensions for multiplication") result_dict = {} for i in range(self.rows): start_a, end_a = self.row_ptrs[i], self.row_ptrs[i + 1] for idx_a in range(start_a, end_a): k, val_a = self.col_indices[idx_a], self.values[idx_a] start_b, end_b = other.row_ptrs[k], other.row_ptrs[k + 1] for idx_b in range(start_b, end_b): j, val_b = other.col_indices[idx_b], other.values[idx_b] result_dict[(i, j)] = result_dict.get((i, j), 0.0) + val_a * val_b result_dict = {k: v for k, v in result_dict.items() if v != 0.0} return SparseMatrixCSR.from_dict(self.rows, other.cols, result_dict) def transpose(self) -> "SparseMatrixCSR": transposed = {(c, r): v for (r, c), v in self.to_dict().items()} return SparseMatrixCSR.from_dict(self.cols, self.rows, transposed)
Solution 3: Row-based dictionary representation
Row-based dictionary representation using {row: {col: value}}. Balances DOK simplicity with efficient row access patterns, good for multiplication.
"""Sparse Matrix - Solution 3: Row-based dictionary representation.Balances simplicity of DOK with better row-access patterns."""from dataclasses import dataclass, fieldfrom typing import Dict@dataclassclass SparseMatrixRowDict: rows: int cols: int data: Dict[int, Dict[int, float]] = field(default_factory=dict) def get(self, row: int, col: int) -> float: if not (0 <= row < self.rows and 0 <= col < self.cols): raise IndexError(f"Index ({row}, {col}) out of bounds") return self.data.get(row, {}).get(col, 0.0) def set(self, row: int, col: int, value: float) -> None: if not (0 <= row < self.rows and 0 <= col < self.cols): raise IndexError(f"Index ({row}, {col}) out of bounds") if value == 0.0: if row in self.data: self.data[row].pop(col, None) if not self.data[row]: del self.data[row] else: if row not in self.data: self.data[row] = {} self.data[row][col] = value def add(self, other: "SparseMatrixRowDict") -> "SparseMatrixRowDict": if self.rows != other.rows or self.cols != other.cols: raise ValueError("Matrix dimensions must match") result = SparseMatrixRowDict(self.rows, self.cols) all_rows = set(self.data.keys()) | set(other.data.keys()) for row in all_rows: self_row = self.data.get(row, {}) other_row = other.data.get(row, {}) all_cols = set(self_row.keys()) | set(other_row.keys()) for col in all_cols: val = self_row.get(col, 0.0) + other_row.get(col, 0.0) if val != 0.0: result.set(row, col, val) return result def multiply(self, other: "SparseMatrixRowDict") -> "SparseMatrixRowDict": if self.cols != other.rows: raise ValueError("Incompatible dimensions") result = SparseMatrixRowDict(self.rows, other.cols) for i, row_data in self.data.items(): for k, val_a in row_data.items(): if k in other.data: for j, val_b in other.data[k].items(): curr = result.get(i, j) result.set(i, j, curr + val_a * val_b) return result def transpose(self) -> "SparseMatrixRowDict": result = SparseMatrixRowDict(self.cols, self.rows) for row, row_data in self.data.items(): for col, val in row_data.items(): result.set(col, row, val) return result def get_row(self, row: int) -> Dict[int, float]: """Efficient row access - returns column->value mapping.""" if not (0 <= row < self.rows): raise IndexError(f"Row {row} out of bounds") return dict(self.data.get(row, {})) def nnz(self) -> int: """Count of non-zero elements.""" return sum(len(row) for row in self.data.values())
Question 21 - A* Pathfinding
Difficulty: 7 / 10
Approximate lines of code: 80 LoC
Tags: algorithms
Description
A* is the standard algorithm for finding the shortest path on a weighted graph, combining Dijkstra’s guaranteed optimality with a heuristic that guides the search toward the goal. It maintains a priority queue ordered by f(n) = g(n) + h(n), where g(n) is the actual cost from start and h(n) is the estimated cost to goal. The heuristic must be admissible (never overestimate) to guarantee the optimal path.
For grid-based pathfinding, common heuristics are Manhattan distance (for 4-directional movement) and Octile/Chebyshev distance (for 8-directional). The algorithm explores nodes in order of their f-score, updating path costs when a shorter route is found. Key data structures: a min-heap for the open set and a hash map for tracking best g-scores.
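To make the heuristic choice concrete, here is a minimal sketch of the two admissible options (similar helpers appear in the solutions below); it assumes unit cost for cardinal steps and sqrt(2) for diagonal steps.

import math

def manhattan(a: tuple[int, int], b: tuple[int, int]) -> float:
    # Admissible for 4-directional movement: never exceeds the true path cost.
    return abs(a[0] - b[0]) + abs(a[1] - b[1])

def octile(a: tuple[int, int], b: tuple[int, int]) -> float:
    # Admissible for 8-directional movement where diagonal steps cost sqrt(2).
    dx, dy = abs(a[0] - b[0]), abs(a[1] - b[1])
    return max(dx, dy) + (math.sqrt(2) - 1) * min(dx, dy)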
Part A: Basic A*
Problem: Part A
Implement A* on a 2D grid with 4-directional movement. Grid cells are either open (0) or blocked (1). Return the path from start to goal, or None if no path exists.
Interview comments
Edge cases to probe:
What heuristic for 4-directional vs 8-directional?
Why can’t the heuristic overestimate?
Common mistakes:
Using wrong heuristic for movement type (Manhattan for diagonals overestimates)
Not including start in the reconstructed path
Re-exploring nodes already in closed set (inefficient but correct)
Using the Euclidean heuristic for 4-directional movement (admissible, but weaker than Manhattan, so the search expands more nodes than necessary)
Code solutions
Solutions Overview
Solution 1 is a classic A* with dataclass nodes and Manhattan heuristic for 4-directional movement. Solution 2 adds a Grid class abstraction with 8-directional movement and Octile heuristic. Solution 3 takes a functional approach with weighted terrain support and returns detailed PathResult stats. These vary in movement direction support and terrain cost handling. Core techniques: priority queue (min-heap), f(n) = g(n) + h(n) scoring, admissible heuristics, path reconstruction.
Solution 1: Classic implementation
Classic implementation using a priority queue with dataclass nodes. Uses Manhattan heuristic for 4-directional. Returns path list or None. Clean, minimal code.
"""A* Pathfinding - Solution 1: Classic Implementation with Priority Queue"""from dataclasses import dataclass, fieldfrom heapq import heappush, heappopfrom typing import Callable, Optionalimport math@dataclass(frozen=True)class Point: x: int y: int def neighbors(self) -> list["Point"]: return [Point(self.x + dx, self.y + dy) for dx, dy in [(0, 1), (1, 0), (0, -1), (-1, 0)]]@dataclass(order=True)class Node: f_score: float point: Point = field(compare=False) g_score: float = field(compare=False)def manhattan(a: Point, b: Point) -> float: return abs(a.x - b.x) + abs(a.y - b.y)def euclidean(a: Point, b: Point) -> float: return math.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2)def astar( grid: list[list[int]], start: Point, goal: Point, heuristic: Callable[[Point, Point], float] = manhattan,) -> Optional[list[Point]]: """Find shortest path using A* algorithm. Grid: 0=open, 1=obstacle.""" if not grid or not grid[0]: return None rows, cols = len(grid), len(grid[0]) def is_valid(p: Point) -> bool: return 0 <= p.x < rows and 0 <= p.y < cols and grid[p.x][p.y] == 0 if not is_valid(start) or not is_valid(goal): return None open_set: list[Node] = [] heappush(open_set, Node(heuristic(start, goal), start, 0)) came_from: dict[Point, Point] = {} g_scores: dict[Point, float] = {start: 0} while open_set: current = heappop(open_set) if current.point == goal: path = [goal] while path[-1] in came_from: path.append(came_from[path[-1]]) return path[::-1] for neighbor in current.point.neighbors(): if not is_valid(neighbor): continue tentative_g = current.g_score + 1 if tentative_g < g_scores.get(neighbor, float("inf")): came_from[neighbor] = current.point g_scores[neighbor] = tentative_g f_score = tentative_g + heuristic(neighbor, goal) heappush(open_set, Node(f_score, neighbor, tentative_g)) return None
Solution 2: Grid class with 8-directional movement
Grid-based with 8-directional movement support. Uses Octile heuristic for diagonals. Diagonal moves cost sqrt(2). Includes Grid class for cleaner abstractions.
"""A* Pathfinding - Solution 2: Using a Grid class with 8-directional movement"""from dataclasses import dataclassfrom heapq import heappush, heappopfrom typing import Optionalimport mathCARDINAL = [(0, 1), (1, 0), (0, -1), (-1, 0)]DIAGONAL = [(1, 1), (1, -1), (-1, 1), (-1, -1)]@dataclass(frozen=True)class Cell: row: int col: int@dataclassclass Grid: data: list[list[int]] # 0=walkable, 1=blocked @property def rows(self) -> int: return len(self.data) @property def cols(self) -> int: return len(self.data[0]) if self.data else 0 def walkable(self, c: Cell) -> bool: return 0 <= c.row < self.rows and 0 <= c.col < self.cols and self.data[c.row][c.col] == 0 def neighbors(self, c: Cell, allow_diagonal: bool = False) -> list[tuple[Cell, float]]: dirs = CARDINAL + DIAGONAL if allow_diagonal else CARDINAL result = [] for dr, dc in dirs: neighbor = Cell(c.row + dr, c.col + dc) if self.walkable(neighbor): cost = math.sqrt(2) if (dr != 0 and dc != 0) else 1.0 result.append((neighbor, cost)) return resultdef chebyshev(a: Cell, b: Cell) -> float: return max(abs(a.row - b.row), abs(a.col - b.col))def octile(a: Cell, b: Cell) -> float: dx, dy = abs(a.row - b.row), abs(a.col - b.col) return max(dx, dy) + (math.sqrt(2) - 1) * min(dx, dy)def astar_grid( grid: Grid, start: Cell, goal: Cell, allow_diagonal: bool = True) -> Optional[list[Cell]]: """A* with optional diagonal movement. Returns path or None.""" if not grid.walkable(start) or not grid.walkable(goal): return None heuristic = octile if allow_diagonal else lambda a, b: abs(a.row - b.row) + abs(a.col - b.col) open_set: list[tuple[float, int, Cell]] = [] counter = 0 heappush(open_set, (heuristic(start, goal), counter, start)) came_from: dict[Cell, Cell] = {} g_score: dict[Cell, float] = {start: 0} while open_set: _, _, current = heappop(open_set) if current == goal: path = [] while current in came_from: path.append(current) current = came_from[current] path.append(start) return path[::-1] if g_score.get(current, float("inf")) < g_score[current]: continue # Skip outdated entries for neighbor, cost in grid.neighbors(current, allow_diagonal): tentative = g_score[current] + cost if tentative < g_score.get(neighbor, float("inf")): came_from[neighbor] = current g_score[neighbor] = tentative counter += 1 heappush(open_set, (tentative + heuristic(neighbor, goal), counter, neighbor)) return None
Solution 3: Functional approach with weighted terrain
Functional approach with weighted terrain support. Uses tuple coordinates instead of dataclass. Separates obstacles from weights. Returns PathResult with cost and nodes explored.
"""A* Pathfinding - Solution 3: Functional approach with weighted edges"""from dataclasses import dataclassfrom heapq import heappush, heappopfrom typing import Callable, Optional, TypeAliasCoord: TypeAlias = tuple[int, int]Heuristic: TypeAlias = Callable[[Coord, Coord], float]@dataclassclass PathResult: path: list[Coord] cost: float nodes_explored: intdef manhattan(a: Coord, b: Coord) -> float: return abs(a[0] - b[0]) + abs(a[1] - b[1])def find_path( width: int, height: int, obstacles: set[Coord], weights: dict[Coord, float], start: Coord, goal: Coord, heuristic: Heuristic = manhattan,) -> Optional[PathResult]: """ A* on a grid with weighted cells. obstacles: impassable cells weights: cost to enter cell (default 1.0) """ def in_bounds(c: Coord) -> bool: return 0 <= c[0] < height and 0 <= c[1] < width def passable(c: Coord) -> bool: return c not in obstacles def get_neighbors(c: Coord) -> list[Coord]: r, col = c candidates = [(r - 1, col), (r + 1, col), (r, col - 1), (r, col + 1)] return [n for n in candidates if in_bounds(n) and passable(n)] def edge_cost(to_node: Coord) -> float: return weights.get(to_node, 1.0) if not in_bounds(start) or not in_bounds(goal): return None if start in obstacles or goal in obstacles: return None open_heap: list[tuple[float, Coord]] = [(heuristic(start, goal), start)] came_from: dict[Coord, Coord] = {} g_score: dict[Coord, float] = {start: 0} explored = 0 while open_heap: _, current = heappop(open_heap) explored += 1 if current == goal: path = [current] while path[-1] != start: path.append(came_from[path[-1]]) return PathResult(path[::-1], g_score[goal], explored) for neighbor in get_neighbors(current): new_g = g_score[current] + edge_cost(neighbor) if new_g < g_score.get(neighbor, float("inf")): came_from[neighbor] = current g_score[neighbor] = new_g f = new_g + heuristic(neighbor, goal) heappush(open_heap, (f, neighbor)) return None
Question 22 - HyperLogLog
Difficulty: 9 / 10
Approximate lines of code: 100 LoC
Tags: probabilistic, data-structures
Description
HyperLogLog (HLL) is a probabilistic data structure for estimating the cardinality (count of unique elements) of a set using minimal memory. Redis uses it for PFCOUNT. The key insight: if you hash values uniformly, the probability of seeing a hash with N leading zeros is 1/2^N. So the maximum leading zeros you’ve observed gives you a rough estimate of log2(cardinality). With 10 leading zeros max, you’ve probably seen around 2^10 = 1024 unique elements.
The core data structure is an array of “registers” (typically 2^p registers, where p=14 gives 16384 registers). Each register stores the maximum leading-zeros-plus-one (called rho) seen for hashes that map to that bucket. To add an element: hash it, use the first p bits to pick a register, count leading zeros in the remaining bits, update the register if this count is higher. To estimate cardinality: compute the harmonic mean of 2^(-register) across all registers and apply bias correction.
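As a rough sketch of the add path just described (the SHA-256-based 64-bit hash and the module-level register list are illustrative assumptions, not the reference solution):

import hashlib

P = 14                      # precision: 2^14 registers
M = 1 << P
registers = [0] * M

def _hash64(s: str) -> int:
    # First 64 bits of SHA-256 as an integer (any stable 64-bit hash works).
    return int(hashlib.sha256(s.encode()).hexdigest()[:16], 16)

def add(element: str) -> None:
    h = _hash64(element)
    bucket = h >> (64 - P)                 # top p bits pick the register
    rest = h & ((1 << (64 - P)) - 1)       # remaining 64 - p bits
    # rho = leading zeros within the remaining bits, plus one
    rho = (64 - P) - rest.bit_length() + 1
    registers[bucket] = max(registers[bucket], rho)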
Part A: Basic Structure
Problem: Part A
Implement add(element) and estimate(). Use a 64-bit hash. First p bits select the bucket (register index). Remaining 64-p bits are used to count leading zeros.
Part B: Harmonic Mean Across Buckets
Problem: Part B
Single-bucket estimates are noisy. Split the hash space into m buckets and combine estimates using the harmonic mean. Apply the alpha correction factor based on the bucket count.
def estimate(self) -> int:
    m = len(self.registers)  # 2^precision
    # Harmonic mean of 2^(-register_value)
    harmonic_sum = sum(2.0 ** (-r) for r in self.registers)
    alpha = self._get_alpha(m)  # Bias correction: ~0.7213 for large m
    raw_estimate = alpha * m * m / harmonic_sum
    return int(raw_estimate)

def _get_alpha(self, m: int) -> float:
    if m == 16:
        return 0.673
    if m == 32:
        return 0.697
    if m == 64:
        return 0.709
    return 0.7213 / (1 + 1.079 / m)
Part C: Merge and Edge Cases
Problem: Part C
Implement merge(other) that combines two HLLs - this enables distributed counting. Handle small cardinality correction (linear counting when many registers are zero) and optionally large cardinality correction.
# Merge is simple: take max of each register
def merge(self, other: HyperLogLog) -> HyperLogLog:
    result = HyperLogLog(precision=self.precision)
    result.registers = [max(a, b) for a, b in zip(self.registers, other.registers)]
    return result

# Small range correction in estimate():
if raw_estimate <= 2.5 * m:
    zeros = self.registers.count(0)
    if zeros > 0:
        # Linear counting: more accurate for small cardinalities
        return int(m * math.log(m / zeros))
hll_west = HyperLogLog(precision=12)
hll_east = HyperLogLog(precision=12)

# Add users from west coast server
for user in west_coast_users:
    hll_west.add(user)

# Add users from east coast server
for user in east_coast_users:
    hll_east.add(user)

# Merge to get total unique users (handles duplicates!)
combined = hll_west.merge(hll_east)
total_unique = combined.estimate()
Interview comments
Edge cases to probe:
What if you add the same element 1000 times?
Why is it called “leading zeros plus one” (the rho function)?
What happens when precision is very low (4 bits = 16 registers)?
How do you handle an empty HLL (all registers zero)?
Common mistakes:
Off-by-one in rho function (should be leading zeros + 1, not just leading zeros)
Using low bits instead of high bits for bucket index (high bits have better distribution)
Using arithmetic mean instead of harmonic mean
Forgetting small range correction (linear counting)
Using Python’s hash() which isn’t stable across runs
Code solutions
Solution 1 is a standard HyperLogLog with SHA-256 hashing, harmonic mean estimation, and both small-range and large-range corrections. Solution 2 uses Python’s built-in hash with FNV-style bit mixing for faster hashing and adds batch processing and merge_all class method. Solution 3 is a production-quality HyperLogLog++ implementation with memory-efficient bytearray storage and serialization support. These vary in their hashing approach and additional features for production use. Core techniques: probabilistic counting, leading zeros estimation, harmonic mean, bias correction.
Solution 1: Standard HyperLogLog with SHA-256
Standard HyperLogLog with SHA-256 hashing, harmonic mean estimation, and both small-range (linear counting) and large-range corrections. Clean implementation using dataclass with configurable precision (4-16 bits).
"""HyperLogLog Implementation - Basic VersionUses standard HyperLogLog algorithm with bias correction."""from dataclasses import dataclass, fieldimport hashlibimport math@dataclassclass HyperLogLog: """HyperLogLog cardinality estimator with configurable precision.""" precision: int = 14 # Number of bits for bucket index (4-16 typical) registers: list[int] = field(default_factory=list, repr=False) def __post_init__(self) -> None: if not 4 <= self.precision <= 16: raise ValueError("Precision must be between 4 and 16") num_registers = 1 << self.precision # 2^precision buckets if not self.registers: self.registers = [0] * num_registers def _hash(self, element: str) -> int: """Hash element to 64-bit integer.""" h = hashlib.sha256(element.encode()).hexdigest() return int(h[:16], 16) # Use first 64 bits def _count_leading_zeros(self, value: int, max_bits: int) -> int: """Count leading zeros in the remaining bits after bucket selection.""" if value == 0: return max_bits count = 0 for i in range(max_bits - 1, -1, -1): if value & (1 << i): break count += 1 return count def add(self, element: str) -> None: """Add an element to the HyperLogLog.""" hash_val = self._hash(element) # Use first 'precision' bits to determine bucket bucket_idx = hash_val >> (64 - self.precision) # Count leading zeros in remaining bits (+ 1 for the implicit leading 1) remaining = hash_val & ((1 << (64 - self.precision)) - 1) zeros = self._count_leading_zeros(remaining, 64 - self.precision) + 1 # Keep maximum zeros seen for this bucket self.registers[bucket_idx] = max(self.registers[bucket_idx], zeros) def estimate(self) -> int: """Estimate the cardinality using harmonic mean.""" m = len(self.registers) # Harmonic mean of 2^(-register_value) harmonic_sum = sum(2.0 ** (-r) for r in self.registers) alpha = self._get_alpha(m) raw_estimate = alpha * m * m / harmonic_sum # Small range correction (linear counting) if raw_estimate <= 2.5 * m: zeros = self.registers.count(0) if zeros > 0: return int(m * math.log(m / zeros)) # Large range correction (for 64-bit hash) if raw_estimate > (1 << 32) / 30: return int(-(1 << 64) * math.log(1 - raw_estimate / (1 << 64))) return int(raw_estimate) def _get_alpha(self, m: int) -> float: """Get bias correction constant alpha_m.""" if m == 16: return 0.673 if m == 32: return 0.697 if m == 64: return 0.709 return 0.7213 / (1 + 1.079 / m) def merge(self, other: "HyperLogLog") -> "HyperLogLog": """Merge with another HyperLogLog (must have same precision).""" if self.precision != other.precision: raise ValueError("Cannot merge HyperLogLogs with different precision") merged = HyperLogLog(precision=self.precision) merged.registers = [max(a, b) for a, b in zip(self.registers, other.registers)] return merged
Solution 2: FNV-style hashing with batch processing
Uses Python’s built-in hash with FNV-style bit mixing for faster hashing. Adds add_batch() for streaming, merge_all() class method for combining multiple HLLs, and error_rate() to report theoretical standard error (1.04/sqrt(m)).
"""HyperLogLog Implementation - Using mmh3 (MurmurHash3)Cleaner implementation with explicit register updates and streaming support."""from dataclasses import dataclass, fieldfrom typing import Iteratorimport math@dataclassclass HyperLogLog: """HyperLogLog using built-in hash function.""" precision: int = 12 _registers: list[int] = field(default_factory=list, init=False) def __post_init__(self) -> None: self._num_buckets = 1 << self.precision self._registers = [0] * self._num_buckets # Alpha correction factor self._alpha = 0.7213 / (1 + 1.079 / self._num_buckets) def _hash(self, item: str) -> int: """Generate 64-bit hash using Python's built-in hash with mixing.""" # Use mixing function to get better distribution h = hash(item) # Mix bits using FNV-style constants h = ((h ^ (h >> 33)) * 0xFF51AFD7ED558CCD) & 0xFFFFFFFFFFFFFFFF h = ((h ^ (h >> 33)) * 0xC4CEB9FE1A85EC53) & 0xFFFFFFFFFFFFFFFF return h ^ (h >> 33) def _rho(self, w: int) -> int: """ Position of leftmost 1-bit (1-indexed). This is equivalent to counting leading zeros + 1. """ if w == 0: return 64 - self.precision pos = 1 while (w & 1) == 0: w >>= 1 pos += 1 return pos def add(self, item: str) -> None: """Add an item to the sketch.""" x = self._hash(item) # First p bits determine the bucket j = x >> (64 - self.precision) # Remaining bits used for leading zeros count w = x & ((1 << (64 - self.precision)) - 1) self._registers[j] = max(self._registers[j], self._rho(w)) def add_batch(self, items: Iterator[str]) -> None: """Add multiple items efficiently.""" for item in items: self.add(item) def count(self) -> int: """Return estimated cardinality.""" m = self._num_buckets # Raw harmonic mean estimator z = sum(2.0 ** (-reg) for reg in self._registers) e = self._alpha * m * m / z # Small range correction using linear counting if e <= 2.5 * m: v = self._registers.count(0) if v > 0: return int(m * math.log(m / v)) return int(e) def merge_with(self, other: "HyperLogLog") -> None: """Merge another HyperLogLog into this one (in-place).""" if self.precision != other.precision: raise ValueError(f"Precision mismatch: {self.precision} vs {other.precision}") for i in range(self._num_buckets): self._registers[i] = max(self._registers[i], other._registers[i]) @classmethod def merge_all(cls, hlls: list["HyperLogLog"]) -> "HyperLogLog": """Merge multiple HyperLogLogs into a new one.""" if not hlls: raise ValueError("Cannot merge empty list") precision = hlls[0].precision if any(h.precision != precision for h in hlls): raise ValueError("All HyperLogLogs must have same precision") result = cls(precision=precision) for hll in hlls: result.merge_with(hll) return result def error_rate(self) -> float: """Expected standard error rate: 1.04 / sqrt(m).""" return 1.04 / math.sqrt(self._num_buckets)
Solution 3: Production-quality HyperLogLog++
Production-quality HyperLogLog++ implementation with 64-bit hashes, memory-efficient bytearray storage (1 byte per register), and serialization support. Includes custom xxHash-style mixing function for fast, high-quality hashing.
"""HyperLogLog Implementation - Production-Quality VersionIncludes HyperLogLog++ improvements: 64-bit hashes and bias correction."""from dataclasses import dataclassimport structimport mathdef _hash64(data: bytes) -> int: """64-bit hash using xxHash-style mixing.""" PRIME1 = 0x9E3779B185EBCA87 PRIME2 = 0xC2B2AE3D27D4EB4F PRIME3 = 0x165667B19E3779F9 h = len(data) * PRIME3 for i, byte in enumerate(data): h ^= byte * PRIME1 h = ((h << 31) | (h >> 33)) & 0xFFFFFFFFFFFFFFFF h = (h * PRIME2) & 0xFFFFFFFFFFFFFFFF # Final avalanche h ^= h >> 33 h = (h * PRIME2) & 0xFFFFFFFFFFFFFFFF h ^= h >> 29 h = (h * PRIME3) & 0xFFFFFFFFFFFFFFFF h ^= h >> 32 return h@dataclassclass HyperLogLogPlusPlus: """ HyperLogLog++ implementation with improved accuracy. Reference: Heule, Nunkesser, Hall (2013) """ precision: int = 14 # p: number of bits for indexing (4-18) def __post_init__(self) -> None: if not 4 <= self.precision <= 18: raise ValueError("Precision must be between 4 and 18") self.m = 1 << self.precision # Number of registers self.registers = bytearray(self.m) # Memory efficient: 1 byte per register self._alpha = self._compute_alpha() def _compute_alpha(self) -> float: """Compute bias correction constant.""" m = self.m if m >= 128: return 0.7213 / (1.0 + 1.079 / m) if m == 64: return 0.709 if m == 32: return 0.697 return 0.673 def _leading_zeros(self, val: int, bits: int) -> int: """Count leading zeros plus 1 (the rho function).""" if val == 0: return bits + 1 count = 1 mask = 1 << (bits - 1) while not (val & mask): count += 1 mask >>= 1 return count def add(self, element: str | bytes) -> None: """Add element to the HyperLogLog.""" if isinstance(element, str): element = element.encode('utf-8') h = _hash64(element) # Extract bucket index from high bits idx = h >> (64 - self.precision) # Compute rho on remaining bits w = h & ((1 << (64 - self.precision)) - 1) rho = self._leading_zeros(w, 64 - self.precision) if rho > self.registers[idx]: self.registers[idx] = rho def cardinality(self) -> int: """Estimate cardinality with bias correction.""" # Compute indicator function Z (harmonic mean) z_inv = sum(2.0 ** (-r) for r in self.registers) e = self._alpha * self.m * self.m / z_inv # Apply corrections based on estimate range if e <= 5 * self.m: # Small range: use linear counting if there are zeros v = self.registers.count(0) if v != 0: e_prime = self.m * math.log(self.m / v) return int(e_prime) return int(e) def merge(self, other: "HyperLogLogPlusPlus") -> "HyperLogLogPlusPlus": """Create new HLL from merging two instances.""" if self.precision != other.precision: raise ValueError("Precision mismatch") result = HyperLogLogPlusPlus(precision=self.precision) for i in range(self.m): result.registers[i] = max(self.registers[i], other.registers[i]) return result def serialize(self) -> bytes: """Serialize to bytes for storage/transmission.""" header = struct.pack('!B', self.precision) return header + bytes(self.registers) @classmethod def deserialize(cls, data: bytes) -> "HyperLogLogPlusPlus": """Deserialize from bytes.""" precision = struct.unpack('!B', data[:1])[0] hll = cls(precision=precision) hll.registers = bytearray(data[1:]) return hll def memory_bytes(self) -> int: """Return memory usage in bytes.""" return self.m # 1 byte per register def standard_error(self) -> float: """Theoretical standard error: 1.04/sqrt(m).""" return 1.04 / math.sqrt(self.m)
Question 23 - Collision Detection
Difficulty: 3 / 10
Approximate lines of code: 80 LoC
Tags: game/simulation, algorithms
Description
Collision detection determines whether two geometric shapes overlap. In 2D games and physics engines, the most common shapes are axis-aligned bounding boxes (AABBs) and circles. An AABB is defined by its min/max corners (min_x, min_y, max_x, max_y), while a circle is defined by center (x, y) and radius. The key insight for efficiency is to avoid expensive operations like square roots when possible - comparing squared distances works for most circle checks.
For large numbers of objects, a two-phase approach is used: broad-phase (spatial partitioning like grids or quadtrees) quickly eliminates pairs that can’t possibly collide, then narrow-phase performs precise geometric checks on remaining candidates.
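A compact sketch of the broad-phase bucketing idea (a uniform grid keyed by cell coordinates; CELL and the tuple-based boxes are illustrative, the solutions below use proper shape classes):

from collections import defaultdict

CELL = 50.0  # illustrative cell size
grid: dict[tuple[int, int], set[int]] = defaultdict(set)

def cells_for(min_x: float, min_y: float, max_x: float, max_y: float):
    # Every grid cell the bounding box touches.
    for cx in range(int(min_x // CELL), int(max_x // CELL) + 1):
        for cy in range(int(min_y // CELL), int(max_y // CELL) + 1):
            yield (cx, cy)

def insert(obj_id: int, box: tuple[float, float, float, float]) -> None:
    for cell in cells_for(*box):
        grid[cell].add(obj_id)

def candidates(box: tuple[float, float, float, float]) -> set[int]:
    # Broad phase: only objects sharing a cell proceed to narrow-phase checks.
    found: set[int] = set()
    for cell in cells_for(*box):
        found |= grid[cell]
    return found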
Part A: AABB vs AABB Collision
Problem: Part A
Two AABBs collide if and only if they overlap on both axes. Check if one box is completely to the left, right, above, or below the other - if none of these are true, they collide.
Part B: Circle vs Circle
Problem: Part B
Two circles collide if the distance between their centers is less than or equal to the sum of their radii. Avoid sqrt by comparing squared values: dx^2 + dy^2 <= (r1 + r2)^2.
Part C: AABB vs Circle
Problem: Part C
Find the point on the AABB closest to the circle’s center (clamp the center coordinates to the box bounds). Then check whether that closest point is within the circle’s radius.
box = AABB(min_x=0, min_y=0, max_x=10, max_y=10)

# Circle inside box
circle_inside = Circle(x=5, y=5, radius=2)
aabb_vs_circle(box, circle_inside)  # True

# Circle outside, near corner
circle_corner = Circle(x=12, y=12, radius=3)
# Closest point on box is (10, 10)
# Distance to center: sqrt(4+4) = 2.83 < 3
aabb_vs_circle(box, circle_corner)  # True

# Circle far from box
circle_far = Circle(x=20, y=20, radius=3)
# Closest point is (10, 10), distance = 14.14 > 3
aabb_vs_circle(box, circle_far)  # False
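A minimal sketch of the three narrow-phase checks the example above relies on (touching shapes count as colliding here; the solutions below implement the same logic with more structure):

from dataclasses import dataclass

@dataclass
class AABB:
    min_x: float
    min_y: float
    max_x: float
    max_y: float

@dataclass
class Circle:
    x: float
    y: float
    radius: float

def aabb_vs_aabb(a: AABB, b: AABB) -> bool:
    # Overlap on both axes <=> neither box is entirely on one side of the other.
    return not (a.max_x < b.min_x or b.max_x < a.min_x or
                a.max_y < b.min_y or b.max_y < a.min_y)

def circle_vs_circle(a: Circle, b: Circle) -> bool:
    dx, dy = a.x - b.x, a.y - b.y
    return dx * dx + dy * dy <= (a.radius + b.radius) ** 2  # no sqrt needed

def aabb_vs_circle(box: AABB, c: Circle) -> bool:
    # Clamp the centre to the box to find the closest point, then radius check.
    cx = max(box.min_x, min(c.x, box.max_x))
    cy = max(box.min_y, min(c.y, box.max_y))
    dx, dy = c.x - cx, c.y - cy
    return dx * dx + dy * dy <= c.radius ** 2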
Interview comments
Edge cases to probe:
Do touching edges/surfaces count as collision?
What about zero-radius circles or zero-area boxes?
Circle center exactly on box edge?
How would you handle rotated rectangles (OBB)?
Common mistakes:
Using sqrt when squared comparison works (performance)
AABB-Circle: checking if center is inside, not closest point distance
Off-by-one with <= vs < (do touching objects collide?)
Forgetting to handle the case where circle center is inside the box
Code solutions
Solution 1 provides basic narrow-phase collision detection with separate functions for AABB-AABB, circle-circle, and AABB-circle checks. Solution 2 adds a uniform grid for broad-phase spatial partitioning, reducing the number of narrow-phase checks needed. Solution 3 uses a quadtree for adaptive spatial partitioning, which is more efficient than uniform grids when objects are non-uniformly distributed. The key difference is how they handle broad-phase culling: none vs uniform grid vs adaptive quadtree. Core techniques: AABB overlap tests, squared distance comparisons, spatial partitioning (grids, quadtrees).
Solution 1: Basic Collision Detection
A straightforward implementation with separate functions for each collision type. Uses dataclasses for shapes and avoids sqrt by comparing squared distances.
"""Solution 1: Basic Collision DetectionStraightforward implementation using dataclasses for shapes."""from dataclasses import dataclassimport math@dataclassclass AABB: """Axis-Aligned Bounding Box defined by min/max corners.""" min_x: float min_y: float max_x: float max_y: float@dataclassclass Circle: """Circle defined by center point and radius.""" x: float y: float radius: floatdef aabb_vs_aabb(a: AABB, b: AABB) -> bool: """Check if two AABBs overlap.""" if a.max_x < b.min_x or b.max_x < a.min_x: return False if a.max_y < b.min_y or b.max_y < a.min_y: return False return Truedef circle_vs_circle(a: Circle, b: Circle) -> bool: """Check if two circles overlap.""" dx = a.x - b.x dy = a.y - b.y distance_squared = dx * dx + dy * dy radius_sum = a.radius + b.radius return distance_squared <= radius_sum * radius_sumdef aabb_vs_circle(box: AABB, circle: Circle) -> bool: """Check if an AABB and circle overlap.""" # Find the closest point on the AABB to the circle center closest_x = max(box.min_x, min(circle.x, box.max_x)) closest_y = max(box.min_y, min(circle.y, box.max_y)) # Calculate distance from circle center to closest point dx = circle.x - closest_x dy = circle.y - closest_y distance_squared = dx * dx + dy * dy return distance_squared <= circle.radius * circle.radius
Solution 2: Collision Detection with Spatial Partitioning (Grid)
Adds a uniform grid for broad-phase spatial partitioning. Objects are inserted into grid cells based on their bounding box, and queries return only objects in relevant cells as collision candidates.
"""Solution 2: Collision Detection with Spatial Partitioning (Grid)Uses a uniform grid for efficient broad-phase collision detection."""from dataclasses import dataclass, fieldfrom typing import List, Set, Tuple, Unionimport math@dataclassclass AABB: min_x: float min_y: float max_x: float max_y: float@dataclassclass Circle: x: float y: float radius: float def to_aabb(self) -> AABB: """Convert circle to bounding AABB for broad-phase.""" return AABB( self.x - self.radius, self.y - self.radius, self.x + self.radius, self.y + self.radius )Shape = Union[AABB, Circle]@dataclassclass SpatialGrid: """Uniform grid for spatial partitioning.""" cell_size: float cells: dict = field(default_factory=dict) def _get_cell(self, x: float, y: float) -> Tuple[int, int]: return (int(x // self.cell_size), int(y // self.cell_size)) def _get_cells_for_aabb(self, aabb: AABB) -> List[Tuple[int, int]]: """Get all cells that an AABB occupies.""" min_cell = self._get_cell(aabb.min_x, aabb.min_y) max_cell = self._get_cell(aabb.max_x, aabb.max_y) cells = [] for cx in range(min_cell[0], max_cell[0] + 1): for cy in range(min_cell[1], max_cell[1] + 1): cells.append((cx, cy)) return cells def insert(self, obj_id: int, shape: Shape) -> None: """Insert an object into the grid.""" aabb = shape.to_aabb() if isinstance(shape, Circle) else shape for cell in self._get_cells_for_aabb(aabb): if cell not in self.cells: self.cells[cell] = set() self.cells[cell].add(obj_id) def query(self, shape: Shape) -> Set[int]: """Find all objects that might collide with the given shape.""" aabb = shape.to_aabb() if isinstance(shape, Circle) else shape candidates = set() for cell in self._get_cells_for_aabb(aabb): if cell in self.cells: candidates.update(self.cells[cell]) return candidates def clear(self) -> None: self.cells.clear()def check_collision(a: Shape, b: Shape) -> bool: """Check collision between any two shapes.""" if isinstance(a, AABB) and isinstance(b, AABB): return not (a.max_x < b.min_x or b.max_x < a.min_x or a.max_y < b.min_y or b.max_y < a.min_y) if isinstance(a, Circle) and isinstance(b, Circle): dx, dy = a.x - b.x, a.y - b.y return dx*dx + dy*dy <= (a.radius + b.radius) ** 2 # AABB vs Circle box, circle = (a, b) if isinstance(a, AABB) else (b, a) closest_x = max(box.min_x, min(circle.x, box.max_x)) closest_y = max(box.min_y, min(circle.y, box.max_y)) dx, dy = circle.x - closest_x, circle.y - closest_y return dx*dx + dy*dy <= circle.radius ** 2
Solution 3: Collision Detection with Quadtree Spatial Partitioning
Uses a quadtree for adaptive spatial partitioning, which is more efficient than uniform grids for non-uniform object distributions. Subdivides recursively when capacity is exceeded.
"""Solution 3: Collision Detection with Quadtree Spatial PartitioningMore efficient than uniform grid for non-uniform object distributions."""from dataclasses import dataclass, fieldfrom typing import List, Optional, Set, Union@dataclassclass AABB: min_x: float min_y: float max_x: float max_y: float def contains_point(self, x: float, y: float) -> bool: return self.min_x <= x <= self.max_x and self.min_y <= y <= self.max_y def intersects(self, other: "AABB") -> bool: return not (self.max_x < other.min_x or other.max_x < self.min_x or self.max_y < other.min_y or other.max_y < self.min_y)@dataclassclass Circle: x: float y: float radius: float def to_aabb(self) -> AABB: return AABB(self.x - self.radius, self.y - self.radius, self.x + self.radius, self.y + self.radius)Shape = Union[AABB, Circle]@dataclassclass QuadTree: boundary: AABB capacity: int = 4 objects: List[tuple] = field(default_factory=list) # (id, shape) divided: bool = False nw: Optional["QuadTree"] = None ne: Optional["QuadTree"] = None sw: Optional["QuadTree"] = None se: Optional["QuadTree"] = None def _subdivide(self) -> None: b = self.boundary mid_x, mid_y = (b.min_x + b.max_x) / 2, (b.min_y + b.max_y) / 2 self.nw = QuadTree(AABB(b.min_x, mid_y, mid_x, b.max_y), self.capacity) self.ne = QuadTree(AABB(mid_x, mid_y, b.max_x, b.max_y), self.capacity) self.sw = QuadTree(AABB(b.min_x, b.min_y, mid_x, mid_y), self.capacity) self.se = QuadTree(AABB(mid_x, b.min_y, b.max_x, mid_y), self.capacity) self.divided = True def insert(self, obj_id: int, shape: Shape) -> bool: aabb = shape.to_aabb() if isinstance(shape, Circle) else shape if not self.boundary.intersects(aabb): return False if len(self.objects) < self.capacity and not self.divided: self.objects.append((obj_id, shape)) return True if not self.divided: self._subdivide() for child in (self.nw, self.ne, self.sw, self.se): child.insert(obj_id, shape) return True def query(self, shape: Shape) -> Set[int]: aabb = shape.to_aabb() if isinstance(shape, Circle) else shape found = set() if not self.boundary.intersects(aabb): return found for obj_id, _ in self.objects: found.add(obj_id) if self.divided: for child in (self.nw, self.ne, self.sw, self.se): found.update(child.query(shape)) return founddef narrow_phase(a: Shape, b: Shape) -> bool: """Precise collision check between two shapes.""" if isinstance(a, AABB) and isinstance(b, AABB): return a.intersects(b) if isinstance(a, Circle) and isinstance(b, Circle): dx, dy = a.x - b.x, a.y - b.y return dx*dx + dy*dy <= (a.radius + b.radius) ** 2 box, circle = (a, b) if isinstance(a, AABB) else (b, a) cx = max(box.min_x, min(circle.x, box.max_x)) cy = max(box.min_y, min(circle.y, box.max_y)) dx, dy = circle.x - cx, circle.y - cy return dx*dx + dy*dy <= circle.radius ** 2
Question 24 - Job Scheduler
Difficulty: 8 / 10
Approximate lines of code: 110 LoC
Tags: scheduling
Description
A job scheduler executes tasks based on priority while respecting dependency constraints. Jobs with higher priority run first, but only after all their dependencies have completed. The core data structures are a min-heap (priority queue) for O(log n) selection of the highest-priority ready job, and a graph of dependencies to track which jobs are blocked.
Internally, each job has a status (pending, running, completed, cancelled), a priority, and a set of dependency job IDs. When a job completes, you notify all jobs that depend on it. When all dependencies are satisfied, the job becomes ready and enters the priority queue.
Note: The sortedcontainers library (SortedList) is available. Unlike heapq, SortedList supports O(log n) removal of cancelled jobs instead of lazy deletion.
Part A: Priority Queue
Problem: Part A
Implement a scheduler where jobs execute in priority order. Higher priority means the job runs sooner. Use a max-heap (negate priorities in a min-heap) to efficiently select the next job.
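A small sketch of the max-heap trick, since heapq is a min-heap (the names here are illustrative):

import heapq

ready: list[tuple[int, str]] = []          # (-priority, job_id)
heapq.heappush(ready, (-2, "compile"))
heapq.heappush(ready, (-3, "lint"))

neg_priority, job_id = heapq.heappop(ready)
print(job_id, -neg_priority)               # lint 3  (highest priority pops first)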
Part B: Dependencies
Problem: Part B
Jobs can depend on other jobs. A job only becomes ready when all its dependencies have completed. Track a waiting_on count that decrements as dependencies finish.
scheduler = JobScheduler()
scheduler.add_job("compile", task=compile_fn, priority=2)
scheduler.add_job("test", task=test_fn, priority=5, dependencies={"compile"})
scheduler.add_job("lint", task=lint_fn, priority=3)

# "test" has priority 5, but depends on "compile"
# "lint" has priority 3 with no dependencies

# Internal state:
# heap: [(-3, "lint"), (-2, "compile")]  # "test" not in heap yet
# waiting_on: {"test": 1, "compile": 0, "lint": 0}
# dependents: {"compile": ["test"]}

scheduler.run_next()  # Runs "lint" (priority 3, ready)
scheduler.run_next()  # Runs "compile" (priority 2, ready)
                      # -> "test" now has waiting_on=0, enters heap
scheduler.run_next()  # Runs "test" (priority 5, now ready)
Part C: Cancellation
Problem: Part C
Add cancel(job_id) that cancels a pending job. When a job is cancelled, all jobs that depend on it (directly or transitively) can never run and should also be marked as failed/cancelled.
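One way to express the cascade, as a sketch that assumes a status dict and a dependents reverse-edge map like the ones used in the solutions below:

# Sketch of cascade cancellation over reverse edges (dependents map assumed).
def cascade_cancel(job_id: str, status: dict[str, str],
                   dependents: dict[str, list[str]]) -> None:
    if status.get(job_id) != "pending":
        return                                  # running/done jobs are untouched
    status[job_id] = "cancelled"
    for child in dependents.get(job_id, []):
        cascade_cancel(child, status, dependents)   # transitive cancellation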
Interview comments
Edge cases to probe:
What if a dependency doesn’t exist yet when you add a job? (Track as pending)
What happens if you cancel a job that’s already running? (Cannot cancel)
What if a job fails during execution? (Should cascade like cancellation)
What if you add a job with a dependency that already completed? (Should be immediately ready)
Common mistakes:
Re-sorting entire list instead of using a heap (O(n log n) vs O(log n))
Linear scan to find dependents when a job completes (maintain reverse edges)
Not skipping cancelled jobs when popping from heap
Forgetting to check if dependency already exists and is completed when adding job
Code solutions
Solutions Overview
Solution 1 uses a simple heap with set-based dependency tracking, re-scanning all jobs on each completion. Solution 2 adds explicit reverse edges (dependents list) with a waiting_on counter for O(1) readiness checks and cascade cancellation. Solution 3 extends to async-ready design with result storage and exception propagation to dependent jobs. These vary in their approach to dependency notification efficiency and failure handling. Core techniques: min-heap/priority queue, topological ordering, graph traversal.
Solution 1: Simple Heap-Based
Simple heap-based approach with sets for dependency tracking. Re-scans all pending jobs after each completion to find newly ready jobs. Clear and correct but O(n) on completion.
"""Job Scheduler - Solution 1: Simple Heap-Based ApproachUses a priority queue (heap) for scheduling. Dependencies tracked via sets.Jobs are processed when all dependencies are satisfied."""from dataclasses import dataclass, fieldfrom typing import Callable, Optionalfrom enum import Enumimport heapqclass JobStatus(Enum): PENDING = "pending" RUNNING = "running" COMPLETED = "completed" CANCELLED = "cancelled"@dataclass(order=True)class Job: priority: int job_id: str = field(compare=False) task: Callable[[], None] = field(compare=False) dependencies: set[str] = field(default_factory=set, compare=False) status: JobStatus = field(default=JobStatus.PENDING, compare=False)class JobScheduler: def __init__(self) -> None: self._jobs: dict[str, Job] = {} self._ready_queue: list[tuple[int, str]] = [] # (neg_priority, job_id) self._completed: set[str] = set() def add_job( self, job_id: str, task: Callable[[], None], priority: int = 0, dependencies: Optional[set[str]] = None, ) -> None: """Add a job to the scheduler.""" job = Job( priority=priority, job_id=job_id, task=task, dependencies=dependencies or set(), ) self._jobs[job_id] = job self._try_enqueue(job) def _try_enqueue(self, job: Job) -> None: """Enqueue job if all dependencies are satisfied.""" if job.status != JobStatus.PENDING: return if job.dependencies.issubset(self._completed): heapq.heappush(self._ready_queue, (-job.priority, job.job_id)) def cancel(self, job_id: str) -> bool: """Cancel a pending job. Returns True if cancelled.""" if job_id not in self._jobs: return False job = self._jobs[job_id] if job.status == JobStatus.PENDING: job.status = JobStatus.CANCELLED return True return False def run_next(self) -> Optional[str]: """Run the highest priority ready job. Returns job_id or None.""" while self._ready_queue: _, job_id = heapq.heappop(self._ready_queue) job = self._jobs[job_id] if job.status != JobStatus.PENDING: continue job.status = JobStatus.RUNNING job.task() job.status = JobStatus.COMPLETED self._completed.add(job_id) self._check_dependents() return job_id return None def _check_dependents(self) -> None: """Re-check all pending jobs for dependency satisfaction.""" for job in self._jobs.values(): if job.status == JobStatus.PENDING: self._try_enqueue(job) def run_all(self) -> list[str]: """Run all jobs in priority order. Returns list of completed job_ids.""" executed = [] while (job_id := self.run_next()) is not None: executed.append(job_id) return executed
Solution 2: Graph-Based with Reverse Edges
Graph-based approach with explicit reverse edge tracking (dependents list). Maintains waiting_on counter per job for O(1) readiness check. Includes cascade cancellation for dependent jobs.
"""Job Scheduler - Solution 2: Graph-Based with Topological AwarenessUses explicit dependency graph tracking. More efficient dependency resolutionby maintaining reverse edges (dependents list)."""from dataclasses import dataclass, fieldfrom typing import Callable, Optionalfrom enum import Enumimport heapqclass Status(Enum): PENDING = "pending" RUNNING = "running" DONE = "done" CANCELLED = "cancelled"@dataclassclass Job: id: str task: Callable[[], None] priority: int = 0 deps: set[str] = field(default_factory=set) status: Status = field(default=Status.PENDING) waiting_on: int = 0 # Count of unfinished dependenciesclass Scheduler: def __init__(self) -> None: self._jobs: dict[str, Job] = {} self._dependents: dict[str, list[str]] = {} # job_id -> jobs waiting on it self._heap: list[tuple[int, str]] = [] def add( self, job_id: str, task: Callable[[], None], priority: int = 0, deps: Optional[set[str]] = None, ) -> None: deps = deps or set() job = Job(id=job_id, task=task, priority=priority, deps=deps) self._jobs[job_id] = job # Count how many deps are not yet done for dep_id in deps: if dep_id in self._jobs and self._jobs[dep_id].status == Status.DONE: continue job.waiting_on += 1 self._dependents.setdefault(dep_id, []).append(job_id) if job.waiting_on == 0: heapq.heappush(self._heap, (-priority, job_id)) def cancel(self, job_id: str) -> bool: if job_id not in self._jobs: return False job = self._jobs[job_id] if job.status == Status.PENDING: job.status = Status.CANCELLED # Propagate cancellation: jobs depending on this can never run self._cascade_cancel(job_id) return True return False def _cascade_cancel(self, job_id: str) -> None: """Cancel all jobs that depend on a cancelled job.""" for dep_id in self._dependents.get(job_id, []): dep_job = self._jobs.get(dep_id) if dep_job and dep_job.status == Status.PENDING: dep_job.status = Status.CANCELLED self._cascade_cancel(dep_id) def _notify_dependents(self, job_id: str) -> None: """Notify dependents that this job completed.""" for dep_id in self._dependents.get(job_id, []): dep_job = self._jobs.get(dep_id) if dep_job and dep_job.status == Status.PENDING: dep_job.waiting_on -= 1 if dep_job.waiting_on == 0: heapq.heappush(self._heap, (-dep_job.priority, dep_job.id)) def step(self) -> Optional[str]: """Execute one job. Returns job_id or None.""" while self._heap: _, job_id = heapq.heappop(self._heap) job = self._jobs[job_id] if job.status != Status.PENDING: continue job.status = Status.RUNNING job.task() job.status = Status.DONE self._notify_dependents(job_id) return job_id return None def run_all(self) -> list[str]: order = [] while (jid := self.step()) is not None: order.append(jid) return order
Solution 3: Async-Ready with Result Storage
Async-ready design with job result storage and exception propagation. Jobs that fail cause all dependent jobs to be marked failed with the error. Supports retrieving job results after completion.
"""Job Scheduler - Solution 3: Async-Ready with CallbacksDesigned for extensibility to async execution. Uses callbacks for completionnotification and supports job result retrieval."""from dataclasses import dataclass, fieldfrom typing import Callable, Any, Optionalfrom enum import Enum, autoimport heapqclass State(Enum): WAITING = auto() READY = auto() RUNNING = auto() DONE = auto() FAILED = auto() CANCELLED = auto()@dataclassclass JobSpec: id: str fn: Callable[[], Any] priority: int = 0 deps: frozenset[str] = field(default_factory=frozenset) state: State = State.WAITING result: Any = None error: Optional[Exception] = Noneclass AsyncReadyScheduler: def __init__(self) -> None: self._jobs: dict[str, JobSpec] = {} self._waiting_for: dict[str, set[str]] = {} # job -> deps not yet done self._blocked_by: dict[str, set[str]] = {} # dep -> jobs waiting on it self._ready: list[tuple[int, str]] = [] def submit( self, job_id: str, fn: Callable[[], Any], priority: int = 0, deps: Optional[set[str]] = None, ) -> JobSpec: deps_frozen = frozenset(deps or []) job = JobSpec(id=job_id, fn=fn, priority=priority, deps=deps_frozen) self._jobs[job_id] = job # Track which deps are still pending pending_deps = {d for d in deps_frozen if d not in self._jobs or self._jobs[d].state != State.DONE} self._waiting_for[job_id] = pending_deps for dep in pending_deps: self._blocked_by.setdefault(dep, set()).add(job_id) if not pending_deps: job.state = State.READY heapq.heappush(self._ready, (-priority, job_id)) return job def cancel(self, job_id: str) -> bool: job = self._jobs.get(job_id) if not job or job.state not in (State.WAITING, State.READY): return False job.state = State.CANCELLED self._propagate_failure(job_id) return True def _propagate_failure(self, job_id: str) -> None: for blocked_id in self._blocked_by.get(job_id, []): blocked = self._jobs[blocked_id] if blocked.state in (State.WAITING, State.READY): blocked.state = State.FAILED blocked.error = Exception(f"Dependency {job_id} unavailable") self._propagate_failure(blocked_id) def _on_complete(self, job_id: str) -> None: for blocked_id in self._blocked_by.get(job_id, []): waiting = self._waiting_for.get(blocked_id, set()) waiting.discard(job_id) if not waiting: blocked = self._jobs[blocked_id] if blocked.state == State.WAITING: blocked.state = State.READY heapq.heappush(self._ready, (-blocked.priority, blocked_id)) def tick(self) -> Optional[str]: while self._ready: _, job_id = heapq.heappop(self._ready) job = self._jobs[job_id] if job.state != State.READY: continue job.state = State.RUNNING try: job.result = job.fn() job.state = State.DONE except Exception as e: job.state = State.FAILED job.error = e self._propagate_failure(job_id) continue self._on_complete(job_id) return job_id return None def drain(self) -> list[str]: completed = [] while (jid := self.tick()) is not None: completed.append(jid) return completed def get_result(self, job_id: str) -> Any: job = self._jobs.get(job_id) if not job or job.state != State.DONE: raise ValueError(f"Job {job_id} not completed") return job.result
Question 25 - Heartbeat Monitor
Description
A heartbeat monitor tracks the health of distributed nodes by receiving periodic heartbeat signals. If a node fails to send a heartbeat within a timeout period, it is marked as unhealthy. This pattern is fundamental to failure detection in distributed systems - it underpins Kubernetes pod health checks, ZooKeeper session management, and service mesh health monitoring.
The core data structures are: (1) a map from node ID to last heartbeat timestamp, (2) a map tracking current node states (healthy/unhealthy), and (3) optionally a min-heap of deadlines for efficient expiration checking. The key challenge is triggering state-change callbacks only on transitions (healthy→unhealthy or vice versa), not on every check.
Note: The sortedcontainers library (SortedList) is available. Unlike heapq, SortedList supports O(log n) removal when a node’s deadline is updated, avoiding stale heap entries.
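A minimal sketch of transition-only callbacks using a monotonic clock (the module-level dicts and string states are simplifications for illustration):

import time

TIMEOUT = 5.0
last_beat: dict[str, float] = {}
state: dict[str, str] = {}

def heartbeat(node_id: str, on_change=None) -> None:
    last_beat[node_id] = time.monotonic()
    if state.get(node_id) == "unhealthy" and on_change:
        on_change(node_id, "healthy")          # recovery is a transition
    state[node_id] = "healthy"                 # first heartbeat: no callback

def check_nodes(on_change=None) -> list[str]:
    now = time.monotonic()
    newly_failed = []
    for node_id, t in last_beat.items():
        if now - t > TIMEOUT and state.get(node_id) == "healthy":
            state[node_id] = "unhealthy"       # transition, so fire exactly once
            newly_failed.append(node_id)
            if on_change:
                on_change(node_id, "unhealthy")
    return newly_failed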
Part A: Basic Health Tracking
Problem: Part A
Implement heartbeat(node_id) to record a heartbeat, and get_status(node_id) to return current health based on whether the last heartbeat was within the timeout.
Part B: State-Change Callbacks
Problem: Part B
Add callbacks that fire when a node’s state changes. The callback should only fire on transitions, not on every heartbeat or status check.
def on_state_change(node_id: str, new_state: NodeState):
    print(f"{node_id} is now {new_state}")

monitor = HeartbeatMonitor(
    timeout_seconds=5.0,
    on_state_change=on_state_change,
)

monitor.heartbeat("node-1")  # First heartbeat, no callback (not a transition)
monitor.heartbeat("node-1")  # Still healthy, no callback

time.sleep(6)
monitor.check_nodes()        # Prints: "node-1 is now UNHEALTHY"
monitor.check_nodes()        # No output (already unhealthy, not a transition)
monitor.heartbeat("node-1")  # Prints: "node-1 is now HEALTHY"
Part C: Efficient Expiration Checking
Problem: Part C
With many nodes, checking all timestamps on every call is O(n). Use a min-heap to efficiently find nodes whose deadlines have passed. Handle the problem of stale heap entries when a node’s deadline is updated.
monitor = HeartbeatMonitor(timeout_seconds=5.0)

# Register 10000 nodes
for i in range(10000):
    monitor.heartbeat(f"node-{i}")

# Only node-0 misses heartbeat
time.sleep(6)
for i in range(1, 10000):
    monitor.heartbeat(f"node-{i}")

# Efficient: only processes expired entries from heap
failed = monitor.process_expirations()  # Returns ["node-0"]

# Internal heap state uses generation numbers:
# When node-1 sends new heartbeat, generation increments
# Old heap entry for node-1 is stale (generation mismatch), skipped
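A sketch of the stale-entry handling described in the comments above: each node carries a generation counter, and a popped heap entry only counts if its generation still matches.

import heapq
import time

heap: list[tuple[float, int, str]] = []     # (deadline, generation, node_id)
generation: dict[str, int] = {}

def refresh(node_id: str, timeout: float) -> None:
    generation[node_id] = generation.get(node_id, -1) + 1
    heapq.heappush(heap, (time.monotonic() + timeout, generation[node_id], node_id))

def pop_expired(now: float) -> list[str]:
    expired = []
    while heap and heap[0][0] <= now:
        _, gen, node_id = heapq.heappop(heap)
        if generation.get(node_id) != gen:
            continue                        # stale entry: deadline was refreshed
        expired.append(node_id)
    return expired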
Interview comments
Edge cases to probe:
What state should a node be in after its first heartbeat? (Healthy, no transition callback)
What if check_nodes() is called and a node is already unhealthy?
How do you handle a node recovering (sending heartbeat after being marked unhealthy)?
What about nodes that have never sent a heartbeat?
Common mistakes:
Firing callback on every heartbeat, not just state transitions
O(n) scan of all nodes on every check instead of using a heap
Stale heap entries: node refreshes deadline but old entry still in heap causes false failure
Using time.time() instead of time.monotonic() (system clock can jump)
Holding lock while firing callbacks (deadlock if callback calls back into monitor)
Code solutions
Solution 1 uses a simple polling approach with dictionaries tracking last heartbeat times and current states. Solution 2 implements a timer-based approach with a min-heap of deadlines and generation numbers to handle stale entries. Solution 3 provides a thread-safe implementation with a background checker thread and proper lock handling. These vary in efficiency (O(n) polling vs O(log n) heap) and threading model (single-threaded vs background thread). Core techniques: timestamp comparison, min-heap with generation numbers, callbacks outside locks to prevent deadlock.
Solution 1: Simple polling approach
Simple polling approach using dictionaries. Tracks last heartbeat time and current state per node. check_nodes() iterates all nodes to find failures. Callbacks fire outside the main logic after state changes are recorded.
"""Heartbeat Monitor - Solution 1: Simple Polling-Based ApproachA straightforward implementation using a dictionary to track last heartbeat times.Periodically polls to check for failed nodes."""from dataclasses import dataclass, fieldfrom typing import Callablefrom enum import Enumimport timeclass NodeState(Enum): HEALTHY = "healthy" UNHEALTHY = "unhealthy"@dataclassclass HeartbeatMonitor: timeout_seconds: float = 5.0 on_state_change: Callable[[str, NodeState], None] | None = None _last_heartbeat: dict[str, float] = field(default_factory=dict) _node_states: dict[str, NodeState] = field(default_factory=dict) def heartbeat(self, node_id: str) -> None: """Record a heartbeat from a node.""" self._last_heartbeat[node_id] = time.time() previous_state = self._node_states.get(node_id) if previous_state != NodeState.HEALTHY: self._node_states[node_id] = NodeState.HEALTHY if self.on_state_change and previous_state is not None: self.on_state_change(node_id, NodeState.HEALTHY) elif previous_state is None: self._node_states[node_id] = NodeState.HEALTHY def check_nodes(self) -> list[str]: """Check all nodes and return list of failed node IDs.""" now = time.time() failed_nodes = [] for node_id, last_time in self._last_heartbeat.items(): elapsed = now - last_time previous_state = self._node_states.get(node_id) if elapsed > self.timeout_seconds: if previous_state != NodeState.UNHEALTHY: self._node_states[node_id] = NodeState.UNHEALTHY if self.on_state_change: self.on_state_change(node_id, NodeState.UNHEALTHY) failed_nodes.append(node_id) return failed_nodes def get_state(self, node_id: str) -> NodeState | None: """Get the current state of a node.""" return self._node_states.get(node_id) def get_all_nodes(self) -> dict[str, NodeState]: """Get states of all known nodes.""" return dict(self._node_states)
Solution 2: Timer-based with min-heap
Timer-based approach with a min-heap of deadlines. Each node has a generation number that increments on heartbeat. Stale heap entries (generation mismatch) are skipped during processing. Supports injectable time provider for testing.
"""Heartbeat Monitor - Solution 2: Timer-Based with Per-Node DeadlinesUses scheduled deadlines per node with a heap for efficient expiration processing.Handles stale heap entries by tracking generation numbers."""from dataclasses import dataclass, fieldfrom typing import Callable, Protocolfrom enum import Enumimport heapqimport timeclass NodeState(Enum): HEALTHY = "healthy" UNHEALTHY = "unhealthy"class TimeProvider(Protocol): def now(self) -> float: ...class RealTimeProvider: def now(self) -> float: return time.time()@dataclassclass HeapEntry: deadline: float generation: int node_id: str def __lt__(self, other: "HeapEntry") -> bool: return self.deadline < other.deadline@dataclassclass NodeInfo: node_id: str deadline: float generation: int state: NodeState = NodeState.HEALTHY@dataclassclass HeartbeatMonitor: timeout_seconds: float = 5.0 on_state_change: Callable[[str, NodeState], None] | None = None time_provider: TimeProvider = field(default_factory=RealTimeProvider) _nodes: dict[str, NodeInfo] = field(default_factory=dict) _deadline_heap: list[HeapEntry] = field(default_factory=list) def heartbeat(self, node_id: str) -> None: """Record a heartbeat from a node.""" now = self.time_provider.now() new_deadline = now + self.timeout_seconds if node_id in self._nodes: node = self._nodes[node_id] old_state = node.state node.deadline = new_deadline node.generation += 1 node.state = NodeState.HEALTHY heapq.heappush(self._deadline_heap, HeapEntry(new_deadline, node.generation, node_id)) if old_state == NodeState.UNHEALTHY and self.on_state_change: self.on_state_change(node_id, NodeState.HEALTHY) else: node = NodeInfo(node_id=node_id, deadline=new_deadline, generation=0) self._nodes[node_id] = node heapq.heappush(self._deadline_heap, HeapEntry(new_deadline, 0, node_id)) def process_expirations(self) -> list[str]: """Process expired deadlines and return list of newly failed nodes.""" now = self.time_provider.now() newly_failed = [] while self._deadline_heap and self._deadline_heap[0].deadline <= now: entry = heapq.heappop(self._deadline_heap) current_node = self._nodes.get(entry.node_id) # Skip stale entries (generation mismatch means deadline was updated) if current_node is None or current_node.generation != entry.generation: continue if current_node.state == NodeState.HEALTHY: current_node.state = NodeState.UNHEALTHY newly_failed.append(entry.node_id) if self.on_state_change: self.on_state_change(entry.node_id, NodeState.UNHEALTHY) return newly_failed def get_state(self, node_id: str) -> NodeState | None: return self._nodes[node_id].state if node_id in self._nodes else None def get_all_nodes(self) -> dict[str, NodeState]: return {nid: info.state for nid, info in self._nodes.items()}
Solution 3: Thread-safe with background checker
Production-oriented thread-safe implementation with a background checker thread. Uses locks for thread safety, fires callbacks outside the lock to prevent deadlocks. Includes start/stop lifecycle for the background thread.
"""Heartbeat Monitor - Solution 3: Thread-Safe with Background CheckerProduction-oriented implementation with background thread for automaticfailure detection and thread-safe operations."""from dataclasses import dataclass, fieldfrom typing import Callablefrom enum import Enumfrom threading import Lock, Thread, Eventimport timeclass NodeState(Enum): HEALTHY = "healthy" UNHEALTHY = "unhealthy"@dataclassclass Node: node_id: str last_heartbeat: float state: NodeState = NodeState.HEALTHY@dataclassclass HeartbeatMonitor: timeout_seconds: float = 5.0 check_interval: float = 1.0 on_state_change: Callable[[str, NodeState], None] | None = None _nodes: dict[str, Node] = field(default_factory=dict) _lock: Lock = field(default_factory=Lock) _stop_event: Event = field(default_factory=Event) _checker_thread: Thread | None = field(default=None, init=False) def start(self) -> None: """Start the background checker thread.""" self._stop_event.clear() self._checker_thread = Thread(target=self._check_loop, daemon=True) self._checker_thread.start() def stop(self) -> None: """Stop the background checker thread.""" self._stop_event.set() if self._checker_thread: self._checker_thread.join(timeout=2.0) def _check_loop(self) -> None: """Background loop that checks for failed nodes.""" while not self._stop_event.is_set(): self._check_all_nodes() self._stop_event.wait(self.check_interval) def _check_all_nodes(self) -> None: """Check all nodes for timeout.""" now = time.time() callbacks: list[tuple[str, NodeState]] = [] with self._lock: for node in self._nodes.values(): elapsed = now - node.last_heartbeat if elapsed > self.timeout_seconds and node.state == NodeState.HEALTHY: node.state = NodeState.UNHEALTHY callbacks.append((node.node_id, NodeState.UNHEALTHY)) # Fire callbacks outside lock to prevent deadlock for node_id, state in callbacks: if self.on_state_change: self.on_state_change(node_id, state) def heartbeat(self, node_id: str) -> None: """Record a heartbeat from a node (thread-safe).""" now = time.time() callback: tuple[str, NodeState] | None = None with self._lock: if node_id in self._nodes: node = self._nodes[node_id] node.last_heartbeat = now if node.state == NodeState.UNHEALTHY: node.state = NodeState.HEALTHY callback = (node_id, NodeState.HEALTHY) else: self._nodes[node_id] = Node(node_id=node_id, last_heartbeat=now) if callback and self.on_state_change: self.on_state_change(callback[0], callback[1]) def get_state(self, node_id: str) -> NodeState | None: with self._lock: return self._nodes[node_id].state if node_id in self._nodes else None def get_all_nodes(self) -> dict[str, NodeState]: with self._lock: return {nid: n.state for nid, n in self._nodes.items()}
Question 26 - Bitmap Index
Difficulty: 2 / 10
Approximate lines of code: 120 LoC
Tags: data-structures, storage
Description
A bitmap index accelerates queries on low-cardinality columns (columns with few distinct values, like “status” or “color”). For each distinct value in a column, we maintain a bitmap where bit i is 1 if row i has that value. Queries become bitwise operations: AND for intersection, OR for union, NOT for complement. This is how databases speed up OLAP queries: Oracle supports on-disk bitmap indexes, and PostgreSQL builds bitmaps on the fly for bitmap index scans.
For a table with 5 rows and a “color” column with values [“red”, “blue”, “red”, “green”, “blue”], we create three bitmaps:
red: 10100 (rows 0, 2)
blue: 01001 (rows 1, 4)
green: 00010 (row 3)
Finding rows where color=“red” AND size=“S” becomes red_bitmap & small_bitmap, which is O(n/64) operations using 64-bit words.
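To make this concrete, here is a minimal sketch using plain Python integers as bitmaps. Note that bit i corresponds to row i, so the integers read “backwards” relative to the left-to-right strings above; the size column is an assumption here, matching the one used in Part A’s example.
def bitmap(values, target):
    """Build a bitmap (as a Python int) where bit i is set iff values[i] == target."""
    bits = 0
    for row, value in enumerate(values):
        if value == target:
            bits |= 1 << row
    return bits

color = ["red", "blue", "red", "green", "blue"]
size = ["S", "M", "L", "S", "M"]  # assumed column, matching the Part A example below

red_bitmap = bitmap(color, "red")     # 0b00101 -> rows 0, 2
small_bitmap = bitmap(size, "S")      # 0b01001 -> rows 0, 3
matches = red_bitmap & small_bitmap   # 0b00001 -> only row 0 is red AND small
print([i for i in range(len(color)) if matches & (1 << i)])  # [0]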
Part A: Basic Bitmap Operations
Problem: Part A
Implement a Bitmap class supporting AND, OR, NOT operations and a method to get matching row indices. Use Python’s arbitrary-precision integers as the underlying bit storage.
Part B: Multi-Column Table Queries
Problem: Part B
Build a Table class that indexes multiple columns and supports queries with AND/OR across conditions.
table = Table()table.add_column("color", ["red", "blue", "red", "green", "blue"])table.add_column("size", ["S", "M", "L", "S", "M"])# Internal state:# indexes = {# "color": BitmapIndex with bitmaps for red, blue, green# "size": BitmapIndex with bitmaps for S, M, L# }# Find rows where color="blue" AND size="M"table.query([("color", "blue"), ("size", "M")], op="AND")# blue = 0b10010, M = 0b10010# AND result = 0b10010 -> [1, 4]# Find rows where color="red" OR color="green"table.query([("color", "red"), ("color", "green")], op="OR")# red = 0b00101, green = 0b01000# OR result = 0b01101 -> [0, 2, 3]
Part C: NOT Operation and Compression Discussion
Problem: Part C
Implement NOT correctly (must mask to valid rows only). Discuss compression schemes for sparse bitmaps (RLE, WAH).
red = index.get("red") # bits=0b00101, size=5not_red = ~red # Must mask: ~0b00101 & 0b11111 = 0b11010not_red.get_matching_rows() # [1, 3, 4]# Without masking, ~0b00101 = ...111111111111111111111010 (infinite 1s)# This would return invalid row indices# Compression discussion:# - Run-Length Encoding: Store (value, count) pairs# [1,1,1,1,0,0,0,0] -> [(1,4), (0,4)]# - Word-Aligned Hybrid (WAH): Fixed word size with fill/literal encoding# - When to compress: High cardinality columns, sparse data
Interview comments
Edge cases to probe:
What happens when querying for a value that doesn’t exist in the column?
How does NOT behave if you don’t mask to the valid row range?
What’s the space complexity for a column with N distinct values and M rows?
When would bitmap indexes perform worse than B-tree indexes?
Common mistakes:
NOT without masking to valid row range (returns infinite/invalid rows)
Using Python’s hash() which is non-deterministic across runs
Not handling empty result sets gracefully
Off-by-one errors when converting bitmap to row indices
Assuming bitmap indexes are always better (they’re not for high-cardinality columns)
Code solutions
Solution 1 uses Python’s arbitrary-precision integers for raw bitmap storage with basic AND/OR/NOT operations. Solution 2 implements Run-Length Encoding (RLE) compression storing (value, count) pairs for efficient handling of clustered data. Solution 3 implements Word-Aligned Hybrid (WAH) compression with fixed-size words and fill/literal encoding for CPU-aligned operations. The key difference is the storage format and compression tradeoffs: raw bits for simplicity, RLE for clustered patterns, WAH for practical CPU efficiency. Core techniques: bitwise operations, bit masking for NOT, run-length encoding, word-aligned compression.
Solution 1: Basic Implementation
Basic implementation using Python’s arbitrary-precision integers for bitmaps. Implements AND, OR, NOT with proper masking, and builds indexes from column values.
"""Bitmap Index - Solution 1: Basic ImplementationSimple bitmap indexing with AND/OR/NOT operations, no compression."""from dataclasses import dataclass, fieldfrom typing import Any@dataclassclass Bitmap: """A bitmap representing rows matching a condition.""" bits: int # Using Python's arbitrary-precision int as bit array size: int # Number of rows def __and__(self, other: "Bitmap") -> "Bitmap": return Bitmap(self.bits & other.bits, self.size) def __or__(self, other: "Bitmap") -> "Bitmap": return Bitmap(self.bits | other.bits, self.size) def __invert__(self) -> "Bitmap": mask = (1 << self.size) - 1 return Bitmap(~self.bits & mask, self.size) def get_matching_rows(self) -> list[int]: """Return indices of rows where bit is set.""" return [i for i in range(self.size) if self.bits & (1 << i)]@dataclassclass BitmapIndex: """Bitmap index for a single column.""" column_name: str bitmaps: dict[Any, Bitmap] = field(default_factory=dict) row_count: int = 0 def build(self, values: list[Any]) -> None: """Build bitmap index from column values.""" self.row_count = len(values) value_rows: dict[Any, int] = {} for i, value in enumerate(values): if value not in value_rows: value_rows[value] = 0 value_rows[value] |= (1 << i) self.bitmaps = {v: Bitmap(bits, self.row_count) for v, bits in value_rows.items()} def get(self, value: Any) -> Bitmap: """Get bitmap for a specific value.""" return self.bitmaps.get(value, Bitmap(0, self.row_count))@dataclassclass Table: """Simple table with bitmap indexes.""" columns: dict[str, list[Any]] = field(default_factory=dict) indexes: dict[str, BitmapIndex] = field(default_factory=dict) def add_column(self, name: str, values: list[Any]) -> None: self.columns[name] = values index = BitmapIndex(name) index.build(values) self.indexes[name] = index def query(self, conditions: list[tuple[str, Any]], op: str = "AND") -> list[int]: """Query with multiple conditions combined with AND or OR.""" if not conditions: return [] col, val = conditions[0] result = self.indexes[col].get(val) for col, val in conditions[1:]: bitmap = self.indexes[col].get(val) result = result & bitmap if op == "AND" else result | bitmap return result.get_matching_rows()if __name__ == "__main__": # Smoke tests table = Table() table.add_column("color", ["red", "blue", "red", "green", "blue"]) table.add_column("size", ["S", "M", "L", "S", "M"]) # Test single condition assert table.query([("color", "red")]) == [0, 2], "Red rows should be 0, 2" # Test AND query assert table.query([("color", "blue"), ("size", "M")], "AND") == [1, 4] # Test OR query assert table.query([("color", "red"), ("color", "green")], "OR") == [0, 2, 3] # Test NOT operation red_bitmap = table.indexes["color"].get("red") not_red = ~red_bitmap assert not_red.get_matching_rows() == [1, 3, 4], "NOT red should be 1, 3, 4" print("All tests passed!")
Solution 2: Run-Length Encoding Compression
Implements Run-Length Encoding (RLE) compression for bitmaps. Stores (bit_value, count) pairs instead of raw bits. Efficient for clustered data where same values appear consecutively.
"""Bitmap Index - Solution 2: Run-Length Encoding CompressionCompressed bitmap using RLE for sparse data."""from dataclasses import dataclass, fieldfrom typing import Any@dataclassclass RLEBitmap: """Run-length encoded bitmap: stores (value, count) pairs.""" runs: list[tuple[int, int]] = field(default_factory=list) # (bit_value, run_length) size: int = 0 @classmethod def from_bits(cls, bits: int, size: int) -> "RLEBitmap": """Create RLE bitmap from raw bits.""" if size == 0: return cls([], 0) runs = [] current_bit = bits & 1 count = 0 for i in range(size): bit = (bits >> i) & 1 if bit == current_bit: count += 1 else: runs.append((current_bit, count)) current_bit = bit count = 1 runs.append((current_bit, count)) return cls(runs, size) def to_bits(self) -> int: """Decompress to raw bits.""" bits = 0 pos = 0 for bit_val, count in self.runs: if bit_val: for i in range(count): bits |= (1 << (pos + i)) pos += count return bits def __and__(self, other: "RLEBitmap") -> "RLEBitmap": return RLEBitmap.from_bits(self.to_bits() & other.to_bits(), self.size) def __or__(self, other: "RLEBitmap") -> "RLEBitmap": return RLEBitmap.from_bits(self.to_bits() | other.to_bits(), self.size) def __invert__(self) -> "RLEBitmap": return RLEBitmap([(1 - v, c) for v, c in self.runs], self.size) def get_matching_rows(self) -> list[int]: result = [] pos = 0 for bit_val, count in self.runs: if bit_val: result.extend(range(pos, pos + count)) pos += count return result def compression_ratio(self) -> float: """Return compression ratio (lower is better).""" if self.size == 0: return 0.0 return len(self.runs) * 2 / self.size # 2 values per run vs 1 bit per row@dataclassclass CompressedBitmapIndex: """Bitmap index using RLE compression.""" column_name: str bitmaps: dict[Any, RLEBitmap] = field(default_factory=dict) row_count: int = 0 def build(self, values: list[Any]) -> None: self.row_count = len(values) value_bits: dict[Any, int] = {} for i, value in enumerate(values): if value not in value_bits: value_bits[value] = 0 value_bits[value] |= (1 << i) self.bitmaps = { v: RLEBitmap.from_bits(bits, self.row_count) for v, bits in value_bits.items() } def get(self, value: Any) -> RLEBitmap: return self.bitmaps.get(value, RLEBitmap([(0, self.row_count)], self.row_count))if __name__ == "__main__": # Test RLE compression bitmap = RLEBitmap.from_bits(0b11110000, 8) assert bitmap.runs == [(0, 4), (1, 4)], f"Unexpected runs: {bitmap.runs}" assert bitmap.to_bits() == 0b11110000 # Test operations a = RLEBitmap.from_bits(0b1100, 4) b = RLEBitmap.from_bits(0b1010, 4) assert (a & b).to_bits() == 0b1000 assert (a | b).to_bits() == 0b1110 assert (~a).to_bits() == 0b0011 # Test index index = CompressedBitmapIndex("status") index.build(["active"] * 100 + ["inactive"] * 100) active_bitmap = index.get("active") assert active_bitmap.get_matching_rows() == list(range(100)) print(f"Compression ratio for clustered data: {active_bitmap.compression_ratio():.3f}") # Sparse data compresses well with RLE sparse = RLEBitmap.from_bits(0b1, 1000) print(f"Compression ratio for sparse data: {sparse.compression_ratio():.3f}") print("All tests passed!")
Solution 3: Word-Aligned Hybrid (WAH) Style
Implements Word-Aligned Hybrid (WAH) compression. Uses fixed-size words with fill/literal encoding. More practical for real systems as it aligns with CPU word sizes.
"""Bitmap Index - Solution 3: Word-Aligned Hybrid (WAH) StylePractical approach using fixed-size words with fill/literal encoding."""from dataclasses import dataclass, fieldfrom typing import Anyfrom enum import Enumclass WordType(Enum): LITERAL = 0 # Word contains raw bits FILL_ZEROS = 1 # Word represents N consecutive zero-words FILL_ONES = 2 # Word represents N consecutive one-wordsWORD_SIZE = 31 # Bits per word (leaving 1 for type indicator in real impl)@dataclassclass WAHBitmap: """Word-aligned hybrid compressed bitmap.""" words: list[tuple[WordType, int]] = field(default_factory=list) size: int = 0 @classmethod def from_bits(cls, bits: int, size: int) -> "WAHBitmap": """Compress raw bits into WAH format.""" words = [] num_words = (size + WORD_SIZE - 1) // WORD_SIZE i = 0 while i < num_words: word = (bits >> (i * WORD_SIZE)) & ((1 << WORD_SIZE) - 1) all_zeros = word == 0 all_ones = word == (1 << min(WORD_SIZE, size - i * WORD_SIZE)) - 1 if all_zeros or all_ones: # Count consecutive fill words fill_type = WordType.FILL_ZEROS if all_zeros else WordType.FILL_ONES count = 1 while i + count < num_words: next_word = (bits >> ((i + count) * WORD_SIZE)) & ((1 << WORD_SIZE) - 1) if (fill_type == WordType.FILL_ZEROS and next_word == 0) or \ (fill_type == WordType.FILL_ONES and next_word == (1 << WORD_SIZE) - 1): count += 1 else: break words.append((fill_type, count)) i += count else: words.append((WordType.LITERAL, word)) i += 1 return cls(words, size) def to_bits(self) -> int: """Decompress to raw bits.""" bits = 0 pos = 0 for word_type, value in self.words: if word_type == WordType.LITERAL: bits |= value << (pos * WORD_SIZE) pos += 1 elif word_type == WordType.FILL_ONES: for _ in range(value): bits |= ((1 << WORD_SIZE) - 1) << (pos * WORD_SIZE) pos += 1 else: # FILL_ZEROS pos += value return bits def __and__(self, other: "WAHBitmap") -> "WAHBitmap": return WAHBitmap.from_bits(self.to_bits() & other.to_bits(), self.size) def __or__(self, other: "WAHBitmap") -> "WAHBitmap": return WAHBitmap.from_bits(self.to_bits() | other.to_bits(), self.size) def __invert__(self) -> "WAHBitmap": mask = (1 << self.size) - 1 return WAHBitmap.from_bits(~self.to_bits() & mask, self.size) def get_matching_rows(self) -> list[int]: bits = self.to_bits() return [i for i in range(self.size) if bits & (1 << i)] def word_count(self) -> int: return len(self.words)@dataclassclass WAHIndex: """Bitmap index using WAH compression.""" bitmaps: dict[Any, WAHBitmap] = field(default_factory=dict) row_count: int = 0 def build(self, values: list[Any]) -> None: self.row_count = len(values) value_bits: dict[Any, int] = {} for i, v in enumerate(values): value_bits.setdefault(v, 0) value_bits[v] |= (1 << i) self.bitmaps = {v: WAHBitmap.from_bits(b, self.row_count) for v, b in value_bits.items()} def get(self, value: Any) -> WAHBitmap: return self.bitmaps.get(value, WAHBitmap.from_bits(0, self.row_count))if __name__ == "__main__": # Test WAH compression with clustered data clustered = (1 << 100) - 1 # First 100 bits set wah = WAHBitmap.from_bits(clustered, 200) assert wah.to_bits() == clustered print(f"Clustered data: {wah.word_count()} words for 200 bits") # Test operations a = WAHBitmap.from_bits(0b1100, 4) b = WAHBitmap.from_bits(0b1010, 4) assert (a & b).to_bits() == 0b1000 assert (a | b).to_bits() == 0b1110 # Test index index = WAHIndex() index.build(["A"] * 50 + ["B"] * 50 + ["A"] * 50) assert index.get("A").get_matching_rows() == list(range(50)) + list(range(100, 150)) print("All tests passed!")
Question 27 - Rope
Difficulty: 8 / 10
Approximate lines of code: 130 LoC
Tags: data-structures
Description
A rope is a binary tree data structure for efficiently representing and manipulating long strings. Unlike a standard string where concatenation is O(n), a rope achieves O(1) concatenation by creating a new parent node that points to both strings. This makes ropes ideal for text editors where users frequently insert, delete, and concatenate text at arbitrary positions.
The key insight is that leaf nodes store actual string fragments, while internal nodes store only metadata. Each internal node tracks its “weight” - the total length of its left subtree - which enables O(log n) character indexing by comparing the target index against the weight to decide whether to go left or right.
Part A: Structure and Concatenation
Problem: Part A
Implement the basic rope structure with concat() that joins two ropes in O(1) time, and to_string() that flattens the rope back to a string.
            [weight=7]            <- new root, weight = len(left subtree)
           /          \
    "Hello, "      "World!"       <- leaf nodes with actual text
Part B: Index Access
Problem: Part B
Implement index(i) that retrieves the character at position i in O(log n) time for a balanced tree.
The algorithm: at each node, if i < weight, recurse left. Otherwise, recurse right with i - weight.
r3.index(0) # 'H' - go left (0 < 7)r3.index(7) # 'W' - go right (7 >= 7), then index 0 in right subtreer3.index(12) # '!' - go right, then index 5 in "World!"
Part C: Split
Problem: Part C
Implement split(i) that divides a rope at position i, returning two ropes (left, right).
This is the trickiest operation. You must recursively descend to find the split point, potentially splitting a leaf node itself, then reconstruct the tree structure.
left, right = r3.split(7)left.to_string() # "Hello, "right.to_string() # "World!"# Splitting mid-leaf:r = Rope("abcdef")l, r = r.split(3)l.to_string() # "abc"r.to_string() # "def"
Interview comments
Edge cases to probe:
What happens when concatenating with an empty rope?
What if split position is 0 or equal to length?
How do you handle negative indices?
What’s the time complexity if the tree becomes unbalanced (repeated concat on one side)?
Common mistakes:
Confusing weight (left subtree length) with total node length
Off-by-one errors in split, especially at leaf boundaries
Not handling empty rope/None cases
Forgetting to recurse with adjusted index (i - weight) when going right
Code solutions
Solutions Overview
Solution 1 is a basic recursive implementation using weight-based tree traversal. Solution 2 adds AVL-style balancing with rotations to maintain O(log n) height. Solution 3 takes a functional/immutable approach with frozen dataclasses and Python operator support. The key difference is how they handle tree balance and mutability. Core techniques: binary tree traversal, weight-based indexing, recursive split with tree reconstruction.
Solution 1: Basic recursive implementation
Basic recursive implementation using dataclasses. Uses RopeNode with weight field for left subtree length. Split recursively descends and reconstructs parent nodes.
"""Rope Data Structure - Solution 1: Basic ImplementationA rope is a binary tree where leaves contain string fragments.Supports O(1) concatenation, O(log n) indexing with balanced tree."""from __future__ import annotationsfrom dataclasses import dataclassfrom typing import Optional@dataclassclass RopeNode: """A node in the rope tree. Leaves have text, internal nodes have children.""" left: Optional[RopeNode] = None right: Optional[RopeNode] = None text: Optional[str] = None weight: int = 0 # Length of left subtree (or text length for leaves) @staticmethod def from_string(s: str) -> RopeNode: """Create a leaf node from a string.""" return RopeNode(text=s, weight=len(s)) def is_leaf(self) -> bool: return self.text is not Noneclass Rope: """Rope data structure for efficient string operations.""" def __init__(self, text: str = ""): self.root: Optional[RopeNode] = RopeNode.from_string(text) if text else None def __len__(self) -> int: return self._total_length(self.root) def _total_length(self, node: Optional[RopeNode]) -> int: if node is None: return 0 if node.is_leaf(): return node.weight return node.weight + self._total_length(node.right) def concat(self, other: Rope) -> Rope: """Concatenate two ropes in O(1) time.""" result = Rope() if self.root is None: result.root = other.root elif other.root is None: result.root = self.root else: result.root = RopeNode( left=self.root, right=other.root, weight=len(self) ) return result def index(self, i: int) -> str: """Get character at position i in O(log n) time for balanced tree.""" if i < 0 or i >= len(self): raise IndexError(f"Index {i} out of range") return self._index(self.root, i) def _index(self, node: RopeNode, i: int) -> str: if node.is_leaf(): return node.text[i] if i < node.weight: return self._index(node.left, i) return self._index(node.right, i - node.weight) def split(self, i: int) -> tuple[Rope, Rope]: """Split rope at position i, returning (left, right) ropes.""" if i <= 0: return Rope(), self if i >= len(self): return self, Rope() left_node, right_node = self._split(self.root, i) left_rope, right_rope = Rope(), Rope() left_rope.root, right_rope.root = left_node, right_node return left_rope, right_rope def _split(self, node: RopeNode, i: int) -> tuple[Optional[RopeNode], Optional[RopeNode]]: if node.is_leaf(): left = RopeNode.from_string(node.text[:i]) if i > 0 else None right = RopeNode.from_string(node.text[i:]) if i < len(node.text) else None return left, right if i < node.weight: left, mid = self._split(node.left, i) right = RopeNode(left=mid, right=node.right, weight=node.weight - i) if mid or node.right else node.right return left, right elif i > node.weight: mid, right = self._split(node.right, i - node.weight) left = RopeNode(left=node.left, right=mid, weight=node.weight) if node.left or mid else mid return left, right return node.left, node.right def to_string(self) -> str: """Convert rope back to string.""" return self._collect(self.root) def _collect(self, node: Optional[RopeNode]) -> str: if node is None: return "" if node.is_leaf(): return node.text return self._collect(node.left) + self._collect(node.right)
Solution 2: AVL-style balancing
Adds AVL-style tree balancing to maintain O(log n) height. Tracks node height and rebalances after concat using rotations. Updates weight during rotations.
"""Rope Data Structure - Solution 2: With Tree BalancingImplements AVL-style balancing to maintain O(log n) operations."""from __future__ import annotationsfrom dataclasses import dataclass, fieldfrom typing import Optional@dataclassclass RopeNode: left: Optional[RopeNode] = None right: Optional[RopeNode] = None text: Optional[str] = None weight: int = 0 height: int = 1 @staticmethod def leaf(s: str) -> RopeNode: return RopeNode(text=s, weight=len(s), height=1) def is_leaf(self) -> bool: return self.text is not Nonedef get_height(node: Optional[RopeNode]) -> int: return node.height if node else 0def get_balance(node: RopeNode) -> int: return get_height(node.left) - get_height(node.right)def update_height(node: RopeNode) -> None: node.height = 1 + max(get_height(node.left), get_height(node.right))def total_length(node: Optional[RopeNode]) -> int: if node is None: return 0 if node.is_leaf(): return node.weight return node.weight + total_length(node.right)def rotate_right(y: RopeNode) -> RopeNode: x = y.left t2 = x.right x.right = y y.left = t2 y.weight = total_length(y.left) update_height(y) update_height(x) return xdef rotate_left(x: RopeNode) -> RopeNode: y = x.right t2 = y.left y.left = x x.right = t2 x.weight = total_length(x.left) y.weight = total_length(y.left) update_height(x) update_height(y) return ydef balance(node: RopeNode) -> RopeNode: update_height(node) bal = get_balance(node) if bal > 1: if get_balance(node.left) < 0: node.left = rotate_left(node.left) return rotate_right(node) if bal < -1: if get_balance(node.right) > 0: node.right = rotate_right(node.right) return rotate_left(node) return nodeclass Rope: def __init__(self, text: str = ""): self.root: Optional[RopeNode] = RopeNode.leaf(text) if text else None def __len__(self) -> int: return total_length(self.root) def concat(self, other: Rope) -> Rope: result = Rope() if not self.root: result.root = other.root elif not other.root: result.root = self.root else: node = RopeNode(left=self.root, right=other.root, weight=len(self)) result.root = balance(node) return result def index(self, i: int) -> str: if i < 0 or i >= len(self): raise IndexError(f"Index {i} out of range") node = self.root while node: if node.is_leaf(): return node.text[i] if i < node.weight: node = node.left else: i -= node.weight node = node.right raise IndexError("Invalid rope structure") def split(self, i: int) -> tuple[Rope, Rope]: if i <= 0: return Rope(), self if i >= len(self): return self, Rope() left_node, right_node = self._split(self.root, i) left_rope, right_rope = Rope(), Rope() left_rope.root = left_node right_rope.root = right_node return left_rope, right_rope def _split(self, node: RopeNode, i: int) -> tuple[Optional[RopeNode], Optional[RopeNode]]: if node.is_leaf(): l = RopeNode.leaf(node.text[:i]) if i > 0 else None r = RopeNode.leaf(node.text[i:]) if i < len(node.text) else None return l, r if i <= node.weight: left, mid = self._split(node.left, i) right = self._merge(mid, node.right) return left, right mid, right = self._split(node.right, i - node.weight) left = self._merge(node.left, mid) return left, right def _merge(self, a: Optional[RopeNode], b: Optional[RopeNode]) -> Optional[RopeNode]: if not a: return b if not b: return a return balance(RopeNode(left=a, right=b, weight=total_length(a))) def to_string(self) -> str: parts = [] self._collect(self.root, parts) return "".join(parts) def _collect(self, node: Optional[RopeNode], parts: list[str]) -> None: if node is None: return if node.is_leaf(): parts.append(node.text) else: 
self._collect(node.left, parts) self._collect(node.right, parts)
Solution 3: Functional/immutable approach
Functional/immutable approach using frozen dataclasses. Stores explicit size field. Supports Python operators (+, []), negative indexing, and insert/delete operations built on split.
"""Rope Data Structure - Solution 3: Functional Style with ImmutabilityUses immutable nodes for easier reasoning and thread safety."""from __future__ import annotationsfrom dataclasses import dataclassfrom typing import Optional, Iterator@dataclass(frozen=True)class Node: """Immutable rope node.""" left: Optional[Node] = None right: Optional[Node] = None text: Optional[str] = None weight: int = 0 size: int = 0 @staticmethod def leaf(s: str) -> Node: return Node(text=s, weight=len(s), size=len(s)) @staticmethod def branch(left: Node, right: Node) -> Node: return Node(left=left, right=right, weight=left.size, size=left.size + right.size) def is_leaf(self) -> bool: return self.text is not None@dataclass(frozen=True)class Rope: """Immutable rope data structure.""" root: Optional[Node] = None @staticmethod def from_string(s: str) -> Rope: return Rope(Node.leaf(s)) if s else Rope() def __len__(self) -> int: return self.root.size if self.root else 0 def __add__(self, other: Rope) -> Rope: """Concatenate using + operator.""" if not self.root: return other if not other.root: return self return Rope(Node.branch(self.root, other.root)) def __getitem__(self, i: int) -> str: """Index using [] operator.""" if i < 0: i += len(self) if i < 0 or i >= len(self): raise IndexError(f"Index {i} out of range") return self._index(self.root, i) def _index(self, node: Node, i: int) -> str: if node.is_leaf(): return node.text[i] if i < node.weight: return self._index(node.left, i) return self._index(node.right, i - node.weight) def split(self, i: int) -> tuple[Rope, Rope]: """Split into two ropes at position i.""" if i <= 0: return Rope(), self if i >= len(self): return self, Rope() left, right = self._split(self.root, i) return Rope(left), Rope(right) def _split(self, node: Node, i: int) -> tuple[Optional[Node], Optional[Node]]: if node.is_leaf(): left = Node.leaf(node.text[:i]) if i > 0 else None right = Node.leaf(node.text[i:]) if i < len(node.text) else None return left, right if i <= node.weight: left, mid = self._split(node.left, i) right = self._join(mid, node.right) return left, right mid, right = self._split(node.right, i - node.weight) left = self._join(node.left, mid) return left, right def _join(self, a: Optional[Node], b: Optional[Node]) -> Optional[Node]: if not a: return b if not b: return a return Node.branch(a, b) def insert(self, i: int, s: str) -> Rope: """Insert string at position i, returning new rope.""" left, right = self.split(i) return left + Rope.from_string(s) + right def delete(self, start: int, end: int) -> Rope: """Delete characters from start to end, returning new rope.""" left, _ = self.split(start) _, right = self.split(end) return left + right def __iter__(self) -> Iterator[str]: """Iterate over characters.""" yield from self._iterate(self.root) def _iterate(self, node: Optional[Node]) -> Iterator[str]: if node is None: return if node.is_leaf(): yield from node.text else: yield from self._iterate(node.left) yield from self._iterate(node.right) def to_string(self) -> str: return "".join(self) def __repr__(self) -> str: return f"Rope({self.to_string()!r})"
Question 28 - Spell Checker
Difficulty: 6 / 10
Approximate lines of code: 80 LoC
Tags: algorithms, data-structures
Description
A spell checker validates words against a dictionary and suggests corrections for misspelled words. The core algorithm uses edit distance (Levenshtein distance) to find dictionary words within 1-2 edits of the misspelled word. Suggestions are ranked by a combination of edit distance and word frequency - “the” is more likely than “thee” even if both are 1 edit away from “teh”.
The naive approach computes edit distance against every dictionary word (O(n * m^2) where n=dictionary size, m=word length). Optimizations include BK-trees for faster neighbor search, or generating all possible edits and filtering by dictionary membership.
Edit distance (Levenshtein) counts minimum operations to transform one string to another:
Insert: “helo” → “hello” (1 edit)
Delete: “helloo” → “hello” (1 edit)
Substitute: “hallo” → “hello” (1 edit)
Transpose: “hlelo” → “hello” (1 edit in Damerau-Levenshtein, 2 in plain Levenshtein)
Standard DP solution:
def edit_distance(s1, s2): m, n = len(s1), len(s2) dp = [[0] * (n+1) for _ in range(m+1)] for i in range(m+1): dp[i][0] = i for j in range(n+1): dp[0][j] = j for i in range(1, m+1): for j in range(1, n+1): if s1[i-1] == s2[j-1]: dp[i][j] = dp[i-1][j-1] else: dp[i][j] = 1 + min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1]) return dp[m][n]
Part C: Ranking with Combined Score
Problem: Part C
Rank suggestions by both edit distance and word frequency. Lower distance wins, then higher frequency breaks ties.
# Ranking tuple: (edit_distance, -frequency, word)# Sort ascending: lower distance first, higher frequency first (via negation)candidates = [ (1, -200, "hello"), # distance=1, freq=200 (1, -100, "help"), # distance=1, freq=100 (2, -1000, "the"), # distance=2, freq=1000]candidates.sort()# Result: [hello, help, the]# Even though "the" has highest frequency, distance=2 loses to distance=1
Advanced: Weight keyboard proximity (typing “r” instead of “e” is common since they’re adjacent).
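A sketch of that idea (not one of the solutions below, which use priority tiers instead): a weighted Levenshtein variant where substituting keyboard-adjacent letters costs less than an arbitrary substitution. The adjacency map is passed in; Solution 3’s KEYBOARD_NEIGHBORS would work.
def weighted_edit_distance(s1: str, s2: str, neighbors: dict[str, str]) -> float:
    """Levenshtein DP where a substitution between adjacent keys costs 0.5 instead of 1."""
    m, n = len(s1), len(s2)
    dp = [[0.0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = float(i)
    for j in range(n + 1):
        dp[0][j] = float(j)
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if s1[i - 1] == s2[j - 1]:
                sub = 0.0
            elif s2[j - 1] in neighbors.get(s1[i - 1], ""):
                sub = 0.5  # adjacent-key typo, e.g. typing "r" for "e"
            else:
                sub = 1.0
            dp[i][j] = min(dp[i - 1][j] + 1,        # deletion
                           dp[i][j - 1] + 1,        # insertion
                           dp[i - 1][j - 1] + sub)  # (possibly weighted) substitution
    return dp[m][n]

# With the KEYBOARD_NEIGHBORS map from Solution 3 ('r' is adjacent to 'e'):
# weighted_edit_distance("hrllo", "hello", KEYBOARD_NEIGHBORS) == 0.5
# weighted_edit_distance("hxllo", "hello", KEYBOARD_NEIGHBORS) == 1.0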
Interview comments
Edge cases to probe:
Empty string input?
Word already in dictionary (return empty suggestions)?
Very long words (edit distance computation cost)?
What if multiple words have same distance and frequency?
Common mistakes:
Off-by-one in DP table indexing (using s1[i] instead of s1[i-1])
Not handling case sensitivity consistently
Sorting by frequency ascending instead of descending
Computing edit distance for entire dictionary (slow) - consider BK-tree or candidate generation
Forgetting to handle the case where word is already correct
Code solutions
Solution 1 uses basic Levenshtein distance computed against every dictionary word - simple but O(n * m^2) per query. Solution 2 uses a BK-tree data structure for faster nearest neighbor search, allowing branch pruning based on the triangle inequality property of edit distance. Solution 3 takes a generate-and-filter approach: instead of checking every dictionary word, it generates all strings within 1-2 edits and filters by dictionary membership, also incorporating keyboard proximity for common typos. These vary in their search strategy: brute force, tree-based pruning, or candidate generation.
Core techniques: Levenshtein distance DP, BK-tree construction and search, edit operation generation (insert, delete, substitute, transpose).
Solution 1: Basic Edit Distance Approach
Basic approach using standard Levenshtein distance DP. Computes edit distance against every dictionary word. Simple but O(n * m^2) per query.
"""Spell Checker - Solution 1: Basic Edit Distance ApproachUses Levenshtein distance with frequency-weighted suggestions."""from dataclasses import dataclass, fieldfrom typing import Optional@dataclassclass SpellChecker: """Spell checker using edit distance and word frequency.""" dictionary: set[str] = field(default_factory=set) word_frequency: dict[str, int] = field(default_factory=dict) max_edit_distance: int = 2 def add_word(self, word: str, frequency: int = 1) -> None: """Add a word to the dictionary with its frequency.""" word = word.lower() self.dictionary.add(word) self.word_frequency[word] = frequency def is_correct(self, word: str) -> bool: """Check if a word is spelled correctly.""" return word.lower() in self.dictionary def edit_distance(self, word1: str, word2: str) -> int: """Calculate Levenshtein edit distance between two words.""" m, n = len(word1), len(word2) dp = [[0] * (n + 1) for _ in range(m + 1)] for i in range(m + 1): dp[i][0] = i for j in range(n + 1): dp[0][j] = j for i in range(1, m + 1): for j in range(1, n + 1): if word1[i - 1] == word2[j - 1]: dp[i][j] = dp[i - 1][j - 1] else: dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) return dp[m][n] def get_suggestions(self, word: str, max_suggestions: int = 5) -> list[str]: """Get spelling suggestions sorted by edit distance then frequency.""" word = word.lower() if self.is_correct(word): return [] candidates: list[tuple[int, int, str]] = [] for dict_word in self.dictionary: distance = self.edit_distance(word, dict_word) if distance <= self.max_edit_distance: freq = self.word_frequency.get(dict_word, 0) candidates.append((distance, -freq, dict_word)) candidates.sort() return [c[2] for c in candidates[:max_suggestions]]
Solution 2: BK-Tree for Efficient Lookup
BK-tree approach for faster nearest neighbor search in edit distance space. Tree structure allows pruning branches that cannot contain matches within the distance threshold.
"""Spell Checker - Solution 2: BK-Tree for Efficient LookupUses a BK-tree for faster nearest neighbor search in edit distance space."""from dataclasses import dataclass, fieldfrom typing import Optionaldef levenshtein(s1: str, s2: str) -> int: """Calculate Levenshtein distance between two strings.""" if len(s1) < len(s2): s1, s2 = s2, s1 if len(s2) == 0: return len(s1) prev_row = list(range(len(s2) + 1)) for i, c1 in enumerate(s1): curr_row = [i + 1] for j, c2 in enumerate(s2): insertions = prev_row[j + 1] + 1 deletions = curr_row[j] + 1 substitutions = prev_row[j] + (c1 != c2) curr_row.append(min(insertions, deletions, substitutions)) prev_row = curr_row return prev_row[-1]@dataclassclass BKNode: """Node in a BK-tree.""" word: str children: dict[int, "BKNode"] = field(default_factory=dict)@dataclassclass SpellChecker: """Spell checker using BK-tree for efficient fuzzy matching.""" root: Optional[BKNode] = None dictionary: set[str] = field(default_factory=set) word_frequency: dict[str, int] = field(default_factory=dict) def add_word(self, word: str, frequency: int = 1) -> None: """Add a word to the BK-tree and dictionary.""" word = word.lower() self.dictionary.add(word) self.word_frequency[word] = frequency if self.root is None: self.root = BKNode(word) return node = self.root while True: dist = levenshtein(word, node.word) if dist == 0: return if dist in node.children: node = node.children[dist] else: node.children[dist] = BKNode(word) return def is_correct(self, word: str) -> bool: """Check if a word is spelled correctly.""" return word.lower() in self.dictionary def _search_bk(self, node: BKNode, word: str, max_dist: int) -> list[tuple[str, int]]: """Recursively search BK-tree for words within max_dist.""" results = [] dist = levenshtein(word, node.word) if dist <= max_dist: results.append((node.word, dist)) for d in range(dist - max_dist, dist + max_dist + 1): if d in node.children: results.extend(self._search_bk(node.children[d], word, max_dist)) return results def get_suggestions(self, word: str, max_suggestions: int = 5) -> list[str]: """Get spelling suggestions using BK-tree search.""" word = word.lower() if self.is_correct(word) or self.root is None: return [] matches = self._search_bk(self.root, word, max_dist=2) scored = [(d, -self.word_frequency.get(w, 0), w) for w, d in matches] scored.sort() return [w for _, _, w in scored[:max_suggestions]]
Solution 3: Generate-and-Filter with Common Typo Patterns
Generate-and-filter approach. Instead of checking every dictionary word, generates all strings within 1-2 edits of the input, then filters by dictionary membership. Also includes keyboard proximity for common typos.
"""Spell Checker - Solution 3: Generate-and-Filter with Common Typo PatternsGenerates candidate corrections by applying edit operations, then filters by dictionary."""from dataclasses import dataclass, fieldfrom typing import IteratorKEYBOARD_NEIGHBORS: dict[str, str] = { 'a': 'qwsz', 'b': 'vghn', 'c': 'xdfv', 'd': 'erfcxs', 'e': 'rdsw', 'f': 'rtgvcd', 'g': 'tyhbvf', 'h': 'yujnbg', 'i': 'uojk', 'j': 'uikmnh', 'k': 'iolmj', 'l': 'opk', 'm': 'njk', 'n': 'bhjm', 'o': 'iplk', 'p': 'ol', 'q': 'wa', 'r': 'edft', 's': 'wedxza', 't': 'rfgy', 'u': 'yhji', 'v': 'cfgb', 'w': 'qase', 'x': 'zsdc', 'y': 'tghu', 'z': 'asx',}@dataclassclass SpellChecker: """Spell checker that generates candidates via edit operations.""" dictionary: set[str] = field(default_factory=set) word_frequency: dict[str, int] = field(default_factory=dict) def add_word(self, word: str, frequency: int = 1) -> None: """Add a word to the dictionary.""" word = word.lower() self.dictionary.add(word) self.word_frequency[word] = frequency def is_correct(self, word: str) -> bool: """Check if a word is spelled correctly.""" return word.lower() in self.dictionary def _edits1(self, word: str) -> Iterator[str]: """Generate all strings that are one edit away.""" letters = 'abcdefghijklmnopqrstuvwxyz' splits = [(word[:i], word[i:]) for i in range(len(word) + 1)] for left, right in splits: if right: yield left + right[1:] # deletion if len(right) > 1: yield left + right[1] + right[0] + right[2:] # transposition for c in letters: yield left + c + right # insertion if right: yield left + c + right[1:] # substitution def _keyboard_edits(self, word: str) -> Iterator[str]: """Generate edits based on keyboard proximity (common typos).""" for i, char in enumerate(word): if char in KEYBOARD_NEIGHBORS: for neighbor in KEYBOARD_NEIGHBORS[char]: yield word[:i] + neighbor + word[i + 1:] def _edits2(self, word: str) -> Iterator[str]: """Generate all strings that are two edits away.""" seen = set() for e1 in self._edits1(word): for e2 in self._edits1(e1): if e2 not in seen: seen.add(e2) yield e2 def get_suggestions(self, word: str, max_suggestions: int = 5) -> list[str]: """Get spelling suggestions by generating and filtering candidates.""" word = word.lower() if self.is_correct(word): return [] candidates: list[tuple[int, int, str]] = [] # Priority 1: keyboard typos (distance 0.5 conceptually) for candidate in self._keyboard_edits(word): if candidate in self.dictionary: freq = self.word_frequency.get(candidate, 0) candidates.append((0, -freq, candidate)) # Priority 2: edit distance 1 for candidate in self._edits1(word): if candidate in self.dictionary: freq = self.word_frequency.get(candidate, 0) candidates.append((1, -freq, candidate)) # Priority 3: edit distance 2 (only if needed) if len(candidates) < max_suggestions: for candidate in self._edits2(word): if candidate in self.dictionary: freq = self.word_frequency.get(candidate, 0) candidates.append((2, -freq, candidate)) seen = set() result = [] for _, _, w in sorted(candidates): if w not in seen: seen.add(w) result.append(w) if len(result) >= max_suggestions: break return result
Question 29 - B-Tree
Difficulty: 10 / 10
Approximate lines of code: 90 LoC
Tags: data-structures, storage
Description
A B-tree is a self-balancing search tree optimized for systems that read and write large blocks of data, like databases and filesystems. Unlike binary trees where each node has 2 children, B-tree nodes have many keys and children (determined by the “order” or “minimum degree”). This reduces tree height, minimizing disk I/O operations.
Each node contains sorted keys and associated values. For a B-tree of order t (minimum degree): each non-root node has between t-1 and 2t-1 keys, and between t and 2t children. A node with k keys has k+1 children. All leaves are at the same depth. The tree grows upward (from the root) rather than downward.
Part A: Structure and Search
Problem: Part A
Create a B-tree node that holds multiple keys, values, and child pointers. Implement search(key) that traverses from root to leaf.
# B-tree with order=3 (max 2 keys per node, max 3 children)# Internal state example:## [10, 20]# / | \# [5, 7] [12, 15] [25, 30]## Each node: keys=[...], values=[...], children=[...]tree = BTree(order=3)# After inserting keys 5,7,10,12,15,20,25,30:tree.search(12) # Returns "val_12"# Path: root[10,20] -> key 12 > 10 and < 20 -> middle child [12,15] -> found at index 0tree.search(8) # Returns None# Path: root[10,20] -> key 8 < 10 -> left child [5,7] -> not found, is leaf -> None
Search algorithm (see the sketch after these steps):
At current node, find first key >= search_key (binary or linear search)
If found exact match, return the value
If leaf node, return None (not found)
Otherwise, recurse into the appropriate child
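A minimal sketch of these steps, assuming nodes shaped like Solution 1’s below (parallel keys/values/children lists):
def search(node, key):
    i = 0
    while i < len(node.keys) and key > node.keys[i]:  # step 1: first key >= search key
        i += 1
    if i < len(node.keys) and node.keys[i] == key:    # step 2: exact match
        return node.values[i]
    if not node.children:                             # step 3: leaf and not found
        return None
    return search(node.children[i], key)              # step 4: recurse into child i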
Part B: Insertion with Splitting
Problem: Part B
Implement insert(key, value). The challenge: when a node is full, you must split it before inserting. Use proactive splitting - split full nodes on the way down, not after insertion.
# Starting state (order=3, max 2 keys per node):
#        [10]
#       /    \
#    [5]    [20, 30]   <- right child is full!
tree.insert(25, "val_25")
# Step 1: Root [10] is not full, so start descending toward the right child
# Step 2: Right child [20, 30] is full, so split it BEFORE descending:
#   - One key (20 or 30, depending on the split index) is promoted into the parent
#   - The remaining keys are divided between the two halves
# After the split (promoting 20 here):
#        [10, 20]
#       /   |    \
#    [5]   []    [30]
# (The empty node is an artifact of order 3's tiny capacity - with order >= 4,
#  each half keeps at least one key.)
# Step 3: 25 > 20, so descend into the rightmost child and insert:
#        [10, 20]
#       /   |    \
#    [5]   []    [25, 30]
Split mechanics for a full node with keys [a, b, c] (order=4, so up to 3 keys; split at mid index 1), as sketched below:
Left node keeps: [a]
Right node gets: [c]
Middle key b promoted to parent
Children arrays split similarly (left gets first half, right gets second half)
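A tiny sketch of just the key split; the children array splits the same way, so each half keeps one more child than it has keys:
def split_keys(keys: list) -> tuple[list, object, list]:
    """Split a full node's keys into (left, promoted, right), e.g. [a, b, c] -> ([a], b, [c])."""
    mid = len(keys) // 2
    return keys[:mid], keys[mid], keys[mid + 1:]

left, promoted, right = split_keys(["a", "b", "c"])
assert (left, promoted, right) == (["a"], "b", ["c"])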
Part C: Root Growth
Problem: Part C
When the root itself is full and needs to split, the tree grows taller. This is the ONLY way a B-tree increases height.
# Before (root is full, order=3):# [10, 20] <- Full root!tree.insert(15, "val_15")# Step 1: Root is full - create NEW root with old root as only child# [empty new root]# |# [10, 20]# Step 2: Split the child (old root)# [10] <- new root now has promoted key# / \# [] [20] <- split creates two children# Step 3: Insert 15 into appropriate child# [10]# / \# [] [15, 20]
Interview comments
Edge cases to probe:
What happens when inserting into an empty tree? (Create root with one key)
When exactly do you split - before or after inserting? (Before - proactive splitting)
How many children does a node with k keys have? (k+1)
What’s the minimum/maximum keys in a node? (t-1 to 2t-1 for non-root, root can have 1 to 2t-1)
Common mistakes:
Off-by-one errors in split: children array has one more element than keys array
Forgetting to insert the new child reference in parent after split
Splitting at wrong index (need to handle even vs odd number of keys)
Not creating new root when current root is full (just splitting in place)
Confusing order vs minimum degree terminology (order = max children = 2t)
Code solutions
Solution 1 is a basic recursive implementation using “order” terminology with linear search within nodes. Solution 2 uses minimum degree t (CLRS style) with binary search for key lookup and an explicit leaf flag. Solution 3 adds generic typing and update-on-duplicate semantics, making the B-tree behave like an ordered map. The key differences are in terminology conventions (order vs degree), search strategy (linear vs binary), and duplicate handling (error vs update). Core techniques: binary search, recursive tree traversal, proactive node splitting.
Solution 1: Recursive with order terminology
Basic recursive implementation using “order” terminology (order = max children). Uses linear search within nodes and recursive insert. Clean separation of split logic into a helper method.
"""B-Tree implementation with configurable order (branching factor)."""from dataclasses import dataclass, fieldfrom typing import Optional, Any@dataclassclass BTreeNode: """A node in the B-tree.""" keys: list[Any] = field(default_factory=list) values: list[Any] = field(default_factory=list) children: list["BTreeNode"] = field(default_factory=list) @property def is_leaf(self) -> bool: return len(self.children) == 0@dataclassclass BTree: """B-Tree with configurable order (max children per node).""" order: int # Maximum number of children per node root: BTreeNode = field(default_factory=BTreeNode) @property def max_keys(self) -> int: return self.order - 1 def search(self, key: Any) -> Optional[Any]: """Search for a key and return its value, or None if not found.""" return self._search_node(self.root, key) def _search_node(self, node: BTreeNode, key: Any) -> Optional[Any]: i = 0 while i < len(node.keys) and key > node.keys[i]: i += 1 if i < len(node.keys) and key == node.keys[i]: return node.values[i] if node.is_leaf: return None return self._search_node(node.children[i], key) def insert(self, key: Any, value: Any) -> None: """Insert a key-value pair into the B-tree.""" root = self.root if len(root.keys) == self.max_keys: new_root = BTreeNode(children=[root]) self._split_child(new_root, 0) self.root = new_root self._insert_non_full(self.root, key, value) def _insert_non_full(self, node: BTreeNode, key: Any, value: Any) -> None: i = len(node.keys) - 1 if node.is_leaf: # Find position and insert while i >= 0 and key < node.keys[i]: i -= 1 node.keys.insert(i + 1, key) node.values.insert(i + 1, value) else: while i >= 0 and key < node.keys[i]: i -= 1 i += 1 if len(node.children[i].keys) == self.max_keys: self._split_child(node, i) if key > node.keys[i]: i += 1 self._insert_non_full(node.children[i], key, value) def _split_child(self, parent: BTreeNode, index: int) -> None: """Split a full child node at the given index.""" full_node = parent.children[index] mid = self.max_keys // 2 new_node = BTreeNode( keys=full_node.keys[mid + 1:], values=full_node.values[mid + 1:], children=full_node.children[mid + 1:] if not full_node.is_leaf else [] ) parent.keys.insert(index, full_node.keys[mid]) parent.values.insert(index, full_node.values[mid]) parent.children.insert(index + 1, new_node) full_node.keys = full_node.keys[:mid] full_node.values = full_node.values[:mid] if not full_node.is_leaf: full_node.children = full_node.children[:mid + 1]
Solution 2: CLRS style with binary search
Uses minimum degree t terminology (standard in CLRS). Includes iterative search with binary search for key lookup within nodes. Explicit leaf boolean flag instead of checking children length.
"""B-Tree with iterative search and explicit min/max degree handling."""from dataclasses import dataclass, fieldfrom typing import Optional, Any, Tuple@dataclassclass Node: """B-tree node with keys, values, and optional children.""" keys: list[Any] = field(default_factory=list) values: list[Any] = field(default_factory=list) children: list["Node"] = field(default_factory=list) leaf: bool = True@dataclassclass BTree: """B-Tree using minimum degree t (nodes have t-1 to 2t-1 keys).""" t: int # Minimum degree def __post_init__(self) -> None: self.root = Node() @property def max_keys(self) -> int: return 2 * self.t - 1 def search(self, key: Any) -> Optional[Any]: """Iterative search for key.""" node = self.root while node: i = self._find_key_index(node, key) if i < len(node.keys) and node.keys[i] == key: return node.values[i] if node.leaf: return None node = node.children[i] return None def _find_key_index(self, node: Node, key: Any) -> int: """Binary search for key position in node.""" lo, hi = 0, len(node.keys) while lo < hi: mid = (lo + hi) // 2 if node.keys[mid] < key: lo = mid + 1 else: hi = mid return lo def insert(self, key: Any, value: Any) -> None: """Insert key-value, splitting root if necessary.""" if len(self.root.keys) == self.max_keys: old_root = self.root self.root = Node(children=[old_root], leaf=False) self._split_child(self.root, 0) self._insert_non_full(self.root, key, value) def _split_child(self, parent: Node, idx: int) -> None: """Split parent.children[idx] which must be full.""" t = self.t full = parent.children[idx] new = Node(leaf=full.leaf) # Move upper half to new node new.keys = full.keys[t:] new.values = full.values[t:] if not full.leaf: new.children = full.children[t:] # Promote middle key to parent parent.keys.insert(idx, full.keys[t - 1]) parent.values.insert(idx, full.values[t - 1]) parent.children.insert(idx + 1, new) # Trim the original node full.keys = full.keys[:t - 1] full.values = full.values[:t - 1] if not full.leaf: full.children = full.children[:t] def _insert_non_full(self, node: Node, key: Any, value: Any) -> None: """Insert into a node that has room.""" i = self._find_key_index(node, key) if node.leaf: node.keys.insert(i, key) node.values.insert(i, value) else: if len(node.children[i].keys) == self.max_keys: self._split_child(node, i) if key > node.keys[i]: i += 1 self._insert_non_full(node.children[i], key, value)
Solution 3: Generic typed with update semantics
Generic typed implementation with update-on-duplicate semantics (like a map). Checks if key exists before inserting, and updates the value if found. Shows how B-trees can serve as ordered maps.
"""B-Tree with generic typing and update-on-duplicate semantics."""from dataclasses import dataclass, fieldfrom typing import Generic, TypeVar, OptionalK = TypeVar("K")V = TypeVar("V")@dataclassclass BNode(Generic[K, V]): """Generic B-tree node.""" keys: list[K] = field(default_factory=list) values: list[V] = field(default_factory=list) children: list["BNode[K, V]"] = field(default_factory=list) @property def is_leaf(self) -> bool: return len(self.children) == 0@dataclassclass BTree(Generic[K, V]): """Generic B-Tree with update semantics for duplicate keys.""" order: int root: BNode[K, V] = field(default_factory=BNode) def get(self, key: K) -> Optional[V]: """Get value for key, or None.""" node, idx = self._find(self.root, key) return node.values[idx] if node else None def _find(self, node: BNode[K, V], key: K) -> tuple[Optional[BNode[K, V]], int]: """Find node and index containing key.""" i = 0 while i < len(node.keys) and key > node.keys[i]: i += 1 if i < len(node.keys) and node.keys[i] == key: return node, i if node.is_leaf: return None, -1 return self._find(node.children[i], key) def put(self, key: K, value: V) -> None: """Insert or update key-value pair.""" # Check for existing key first node, idx = self._find(self.root, key) if node: node.values[idx] = value return # Insert new key if len(self.root.keys) == self.order - 1: old = self.root self.root = BNode(children=[old]) self._split(self.root, 0) self._insert(self.root, key, value) def _insert(self, node: BNode[K, V], key: K, value: V) -> None: i = len(node.keys) - 1 if node.is_leaf: while i >= 0 and key < node.keys[i]: i -= 1 node.keys.insert(i + 1, key) node.values.insert(i + 1, value) else: while i >= 0 and key < node.keys[i]: i -= 1 i += 1 if len(node.children[i].keys) == self.order - 1: self._split(node, i) if key > node.keys[i]: i += 1 self._insert(node.children[i], key, value) def _split(self, parent: BNode[K, V], idx: int) -> None: mid = (self.order - 1) // 2 child = parent.children[idx] sibling = BNode[K, V]( keys=child.keys[mid + 1:], values=child.values[mid + 1:], children=child.children[mid + 1:] if not child.is_leaf else [] ) parent.keys.insert(idx, child.keys[mid]) parent.values.insert(idx, child.values[mid]) parent.children.insert(idx + 1, sibling) child.keys = child.keys[:mid] child.values = child.values[:mid] if not child.is_leaf: child.children = child.children[:mid + 1]
Question 30 - Transaction Log
Difficulty: 4 / 10
Approximate lines of code: 110 LoC
Tags: storage
Description
A transaction log provides atomicity for a key-value store: operations either fully complete or fully roll back, with no partial states visible. The key mechanism is storing the old value before each modification, enabling rollback by replaying these saved values in reverse order. This is the foundation of ACID transactions in databases.
Internal state includes: the actual data store (a dict), a write-ahead log (WAL) of operations with their old values, and savepoint markers. When you begin(), you start recording. When you commit(), you discard the log (changes are permanent). When you rollback(), you replay the log backwards to restore the original state.
Part A: Begin, Commit, Rollback
Problem: Part A
Implement basic transaction semantics with begin(), set(key, value), delete(key), commit(), and rollback().
tm = TransactionManager()# Basic transaction flowtm.begin()tm.set("user", "alice")tm.set("count", 1)# Internal WAL: [SET user (old=None), SET count (old=None)]# Store: {"user": "alice", "count": 1}tm.commit()# WAL cleared, changes are permanent# Store: {"user": "alice", "count": 1}# Rollback exampletm.begin()tm.set("user", "bob")# WAL: [SET user (old="alice")]# Store: {"user": "bob", "count": 1}tm.rollback()# Replay WAL backwards: restore user to "alice"# Store: {"user": "alice", "count": 1}
Part B: Savepoints
Problem: Part B
Add nested rollback points within a transaction. savepoint() returns a marker you can roll back to without aborting the entire transaction.
tm.begin()tm.set("x", 10)sp1 = tm.savepoint() # savepoint 1# WAL: [SET x (old=None, sp=0)]tm.set("x", 20)sp2 = tm.savepoint() # savepoint 2# WAL: [SET x (old=None, sp=0), SET x (old=10, sp=1)]tm.set("x", 30)tm.set("y", 100)# WAL: [..., SET x (old=20, sp=2), SET y (old=None, sp=2)]tm.rollback_to(sp2)# Undo all ops with savepoint >= 2# x restored to 20, y removedassert tm.get("x") == 20assert tm.get("y") is Nonetm.rollback_to(sp1)assert tm.get("x") == 10tm.commit()
Part C: Durability via Write-Ahead Log
Problem: Part C
Make transactions survive crashes by writing the log to disk before applying changes. On startup, replay any uncommitted transactions from the log.
# Write to disk before modifying memorytm.begin()tm.set("key", "value")# 1. Write to WAL file: {"op": "SET", "key": "key", "value": "value"}# 2. fsync() the file# 3. Apply to in-memory storetm.commit()# 1. Write COMMIT record to WAL# 2. fsync()# 3. Delete WAL file (or truncate)# Crash recovery on startup:# 1. Read WAL file# 2. If COMMIT record exists: replay all ops# 3. If no COMMIT: discard (transaction was incomplete)
Interview comments
Edge cases to probe:
What happens if you call set() outside a transaction?
What happens if you roll back to an invalid savepoint?
How do you handle delete of a non-existent key?
What if the process crashes between writing to WAL and applying to store?
Common mistakes:
Shallow copy of values when storing old state (mutable values get corrupted)
Not storing the old value for delete operations (can’t restore on rollback)
Replaying WAL in forward order instead of reverse during rollback
Not using fsync for durability (just flush is not enough)
Code solutions
Solution 1 is a simple in-memory implementation using a log entry list that stores operation type, key, old/new values, and savepoint number. Solution 2 adds persistence with a write-ahead log (WAL) that writes JSON entries to disk with fsync for crash recovery. Solution 3 uses the command pattern where each operation (SetCommand, DeleteCommand) knows how to execute and undo itself. The key difference is durability vs simplicity: Solution 1 is memory-only, Solution 2 survives crashes, and Solution 3 provides clean separation of concerns. Core techniques: write-ahead logging, reverse-order rollback, savepoint markers, command pattern with undo stack, fsync for durability.
Solution 1: Simple In-Memory Implementation
Simple in-memory implementation with a log entry list. Each entry stores the operation type, key, old value, new value, and savepoint number. Rollback pops entries from the end and restores old values. Clean and correct but not persistent.
"""Transaction Log - Solution 1: Simple In-Memory ImplementationBasic ACID transactions with write-ahead logging and savepoints."""from dataclasses import dataclass, fieldfrom enum import Enumfrom typing import Anyclass OpType(Enum): SET = "SET" DELETE = "DELETE"@dataclassclass LogEntry: op: OpType key: str old_value: Any new_value: Any savepoint: int = 0@dataclassclass TransactionManager: store: dict[str, Any] = field(default_factory=dict) wal: list[LogEntry] = field(default_factory=list) in_transaction: bool = False savepoint_counter: int = 0 def begin(self) -> None: if self.in_transaction: raise RuntimeError("Transaction already active") self.in_transaction = True self.wal.clear() self.savepoint_counter = 0 def set(self, key: str, value: Any) -> None: self._require_transaction() old = self.store.get(key) self.wal.append(LogEntry(OpType.SET, key, old, value, self.savepoint_counter)) self.store[key] = value def delete(self, key: str) -> None: self._require_transaction() if key not in self.store: raise KeyError(f"Key not found: {key}") old = self.store[key] self.wal.append(LogEntry(OpType.DELETE, key, old, None, self.savepoint_counter)) del self.store[key] def get(self, key: str) -> Any: return self.store.get(key) def savepoint(self) -> int: self._require_transaction() self.savepoint_counter += 1 return self.savepoint_counter def rollback_to(self, sp: int) -> None: self._require_transaction() while self.wal and self.wal[-1].savepoint >= sp: entry = self.wal.pop() if entry.op == OpType.SET: if entry.old_value is None: self.store.pop(entry.key, None) else: self.store[entry.key] = entry.old_value elif entry.op == OpType.DELETE: self.store[entry.key] = entry.old_value self.savepoint_counter = sp - 1 def rollback(self) -> None: self._require_transaction() self.rollback_to(0) self.in_transaction = False def commit(self) -> None: self._require_transaction() self.wal.clear() self.in_transaction = False def _require_transaction(self) -> None: if not self.in_transaction: raise RuntimeError("No active transaction")
Solution 2: Persistent WAL with Recovery
Persistent WAL with crash recovery. Writes JSON log entries to disk with fsync for durability. On startup, replays the log to recover state. Savepoints use a stack of state snapshots for simpler rollback logic.
"""Transaction Log - Solution 2: Persistent WAL with RecoveryWrites log entries to disk for crash recovery."""from dataclasses import dataclassfrom enum import Enumfrom pathlib import Pathfrom typing import Anyimport jsonimport osclass Op(Enum): BEGIN = "BEGIN" SET = "SET" DELETE = "DELETE" COMMIT = "COMMIT" SAVEPOINT = "SAVEPOINT" ROLLBACK_TO = "ROLLBACK_TO"@dataclassclass PersistentTxManager: wal_path: Path store: dict[str, Any] savepoint_stack: list[dict[str, Any]] active: bool = False def __init__(self, wal_path: str = "wal.log"): self.wal_path = Path(wal_path) self.store = {} self.savepoint_stack = [] self.active = False self._recover() def _write_log(self, op: Op, key: str = "", value: Any = None) -> None: entry = {"op": op.value, "key": key, "value": value} with open(self.wal_path, "a") as f: f.write(json.dumps(entry) + "\n") f.flush() os.fsync(f.fileno()) def _recover(self) -> None: if not self.wal_path.exists(): return with open(self.wal_path) as f: entries = [json.loads(line) for line in f if line.strip()] temp_store: dict[str, Any] = {} committed = False for e in entries: op = Op(e["op"]) if op == Op.BEGIN: temp_store = dict(self.store) committed = False elif op == Op.SET: temp_store[e["key"]] = e["value"] elif op == Op.DELETE: temp_store.pop(e["key"], None) elif op == Op.COMMIT: self.store = dict(temp_store) committed = True if committed: self.wal_path.unlink(missing_ok=True) def begin(self) -> None: if self.active: raise RuntimeError("Transaction already active") self._write_log(Op.BEGIN) self.active = True self.savepoint_stack = [dict(self.store)] def set(self, key: str, value: Any) -> None: if not self.active: raise RuntimeError("No active transaction") self._write_log(Op.SET, key, value) self.store[key] = value def delete(self, key: str) -> None: if not self.active: raise RuntimeError("No active transaction") self._write_log(Op.DELETE, key) self.store.pop(key, None) def get(self, key: str) -> Any: return self.store.get(key) def savepoint(self) -> int: if not self.active: raise RuntimeError("No active transaction") self.savepoint_stack.append(dict(self.store)) self._write_log(Op.SAVEPOINT) return len(self.savepoint_stack) - 1 def rollback_to(self, sp: int) -> None: if not self.active or sp >= len(self.savepoint_stack): raise RuntimeError("Invalid savepoint") self._write_log(Op.ROLLBACK_TO, value=sp) self.store = dict(self.savepoint_stack[sp]) self.savepoint_stack = self.savepoint_stack[:sp + 1] def rollback(self) -> None: if not self.active: raise RuntimeError("No active transaction") self.store = dict(self.savepoint_stack[0]) self.savepoint_stack.clear() self.active = False self.wal_path.unlink(missing_ok=True) def commit(self) -> None: if not self.active: raise RuntimeError("No active transaction") self._write_log(Op.COMMIT) self.savepoint_stack.clear() self.active = False self.wal_path.unlink(missing_ok=True)
Solution 3: Command Pattern with Undo Stack
Command pattern with explicit undo stack. Each operation (SetCommand, DeleteCommand) knows how to execute and undo itself. Transactions manage a list of commands and savepoints as indices into that list. Clean separation of concerns.
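The code for this solution isn't included here, so below is a minimal sketch of the idea as described. SetCommand and DeleteCommand come from the description above; everything else (the CommandTransactionManager name, the savepoint-as-stack-index scheme) is assumed for illustration rather than taken from the original solution.
"""Transaction Log - Solution 3 (sketch): Command Pattern with Undo Stack
Illustrative only; interface details are assumed, not the original code."""
from dataclasses import dataclass, field
from typing import Any, Protocol


class Command(Protocol):
    def execute(self, store: dict[str, Any]) -> None: ...
    def undo(self, store: dict[str, Any]) -> None: ...


@dataclass
class SetCommand:
    key: str
    new_value: Any
    old_value: Any = None
    existed: bool = False

    def execute(self, store: dict[str, Any]) -> None:
        self.existed = self.key in store
        self.old_value = store.get(self.key)
        store[self.key] = self.new_value

    def undo(self, store: dict[str, Any]) -> None:
        if self.existed:
            store[self.key] = self.old_value
        else:
            store.pop(self.key, None)


@dataclass
class DeleteCommand:
    key: str
    old_value: Any = None

    def execute(self, store: dict[str, Any]) -> None:
        self.old_value = store[self.key]  # raises KeyError for a missing key
        del store[self.key]

    def undo(self, store: dict[str, Any]) -> None:
        store[self.key] = self.old_value


@dataclass
class CommandTransactionManager:
    store: dict[str, Any] = field(default_factory=dict)
    _undo_stack: list[Command] = field(default_factory=list)
    _savepoints: list[int] = field(default_factory=list)  # undo-stack lengths at each savepoint
    _active: bool = False

    def begin(self) -> None:
        if self._active:
            raise RuntimeError("Transaction already active")
        self._active = True
        self._undo_stack.clear()
        self._savepoints.clear()

    def set(self, key: str, value: Any) -> None:
        self._run(SetCommand(key, value))

    def delete(self, key: str) -> None:
        self._run(DeleteCommand(key))

    def savepoint(self) -> int:
        self._savepoints.append(len(self._undo_stack))
        return len(self._savepoints) - 1

    def rollback_to(self, sp: int) -> None:
        target = self._savepoints[sp]  # raises IndexError for an invalid savepoint
        while len(self._undo_stack) > target:
            self._undo_stack.pop().undo(self.store)
        del self._savepoints[sp + 1:]  # the savepoint itself stays valid for reuse

    def rollback(self) -> None:
        while self._undo_stack:
            self._undo_stack.pop().undo(self.store)
        self._savepoints.clear()
        self._active = False

    def commit(self) -> None:
        self._undo_stack.clear()
        self._savepoints.clear()
        self._active = False

    def _run(self, cmd: Command) -> None:
        if not self._active:
            raise RuntimeError("No active transaction")
        cmd.execute(self.store)
        self._undo_stack.append(cmd)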
An indexed priority queue (IPQ) combines a priority queue with O(1) key lookup, enabling efficient priority updates. This is essential for graph algorithms like Dijkstra’s shortest path, where you need to decrease the priority of a vertex when you discover a shorter path. A standard heap gives O(log n) insert and extract-min, but updating a priority requires O(n) to find the element first.
The key insight is maintaining a hash map from keys to their positions in the heap array. Every time you swap elements during bubble-up or bubble-down, you must also update both keys’ positions in the hash map. This gives O(1) lookup of any key’s heap position, enabling O(log n) priority updates.
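For context, here is roughly how Dijkstra's relaxation loop leans on those two operations. This is a sketch, not part of the question: it assumes an IndexedPriorityQueue with the interface described below (insert, extract_min, update_priority, __contains__, __len__), and graph is a hypothetical adjacency dict mapping each node to a list of (neighbor, weight) pairs with non-negative weights.
def dijkstra(graph: dict[str, list[tuple[str, int]]], source: str) -> dict[str, int]:
    """Sketch: single-source shortest paths using an indexed priority queue."""
    dist: dict[str, int] = {source: 0}
    pq = IndexedPriorityQueue()
    pq.insert(source, 0)
    while len(pq) > 0:
        node, d = pq.extract_min()
        for neighbor, weight in graph.get(node, []):
            candidate = d + weight
            if candidate < dist.get(neighbor, float("inf")):
                dist[neighbor] = candidate
                if neighbor in pq:                           # O(1) position lookup via the index map
                    pq.update_priority(neighbor, candidate)  # O(log n) instead of an O(n) search
                else:
                    pq.insert(neighbor, candidate)
    return dist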
Part A: Basic Operations
Problem: Part A
Implement insert(key, priority), extract_min(), and peek_min(). The key is a unique identifier (like a vertex ID). Store entries as (key, priority) pairs in a heap array, with a hash map tracking key → array index. Duplicate keys should raise an error.
pq = IndexedPriorityQueue()pq.insert("A", 10)pq.insert("B", 5)pq.insert("C", 15)# Internal state:# _heap = [Entry("B", 5), Entry("A", 10), Entry("C", 15)]# _key_to_index = {"B": 0, "A": 1, "C": 2}pq.peek_min() # ("B", 5)pq.extract_min() # ("B", 5)# After extract_min:# _heap = [Entry("A", 10), Entry("C", 15)] # C moved to root, bubbled down# _key_to_index = {"A": 0, "C": 1}pq.insert("B", 5) # OK, B was removedpq.insert("A", 20) # KeyError: "A" already exists
Part B: Update and Remove
Problem: Part B
Add update_priority(key, new_priority) and remove(key). After updating a priority, the element might need to bubble up (if priority decreased) OR bubble down (if priority increased). The hash map enables O(1) lookup of the element’s current position.
pq = IndexedPriorityQueue()pq.insert("A", 10)pq.insert("B", 5)pq.insert("C", 15)# _heap = [Entry("B", 5), Entry("A", 10), Entry("C", 15)]# _key_to_index = {"B": 0, "A": 1, "C": 2}pq.update_priority("C", 3) # C now has lowest priority# After update:# 1. Find C at index 2# 2. Change priority: _heap[2].priority = 3# 3. Since 3 < 15, bubble UP# 4. Compare with parent at (2-1)//2 = 0, swap with B# _heap = [Entry("C", 3), Entry("A", 10), Entry("B", 5)]# Wait, that's wrong - need to continue bubbling B down# Correct state after bubbling:# _heap = [Entry("C", 3), Entry("A", 10), Entry("B", 5)]# No wait: after C bubbles to root, B is at index 2# _key_to_index = {"C": 0, "A": 1, "B": 2}pq.remove("A") # Remove from middle# Swap A with last element, pop, then bubble the swapped element
Part C: Edge Cases
Problem: Part C
Handle operations on empty queue, non-existent keys, and implement decrease_key(key, new_priority) that only updates if the new priority is strictly lower (commonly needed for Dijkstra’s).
pq = IndexedPriorityQueue()pq.extract_min() # IndexError: "Priority queue is empty"pq.peek_min() # IndexError: "Priority queue is empty"pq.update_priority("X", 5) # KeyError: "X" not foundpq.remove("X") # KeyError: "X" not foundpq.insert("A", 10)pq.decrease_key("A", 15) # ValueError: new priority 15 is not less than 10pq.decrease_key("A", 5) # OK, updates to 5"A" in pq # True (implement __contains__)len(pq) # 1 (implement __len__)
Interview comments
Edge cases to probe:
What happens when you update a priority and need to bubble down instead of up?
How do you handle removing an element from the middle of the heap?
What’s the parent/child index formula? (parent: (i-1)//2, children: 2*i+1, 2*i+2)
What if someone tries to insert a duplicate key?
Common mistakes:
Forgetting to update the hash map on EVERY swap during bubble up/down
Only bubbling one direction after update (might need up OR down)
Off-by-one in parent/child calculations: parent is (i-1)//2 not i//2
Not handling remove from middle correctly (must swap with last, pop, then bubble)
Using > instead of >= in bubble comparisons (affects stability)
Code solutions
Solution 1 uses a binary heap with a hash map for O(1) key lookup, maintaining the index map on every swap for O(log n) priority updates. Solution 2 uses Python’s heapq with lazy deletion, marking old entries as invalid instead of physically removing them. Solution 3 is a simple O(n) dictionary-based approach with linear scans for min/max. They differ in complexity vs. simplicity: optimal heap operations vs. simpler lazy deletion vs. trivial but slower linear scan.
Solution 1: Binary heap with hash map
Binary heap with hash map for O(1) key lookup. Maintains the index map on every swap. Explicit bubble-up and bubble-down methods.
"""Indexed Priority Queue - Solution 1: Binary Heap with Hash MapUses a min-heap with a dictionary for O(1) key lookups and O(log n) updates."""from dataclasses import dataclassfrom typing import Generic, TypeVar, OptionalK = TypeVar('K')V = TypeVar('V')@dataclassclass HeapEntry(Generic[K, V]): key: K priority: Vclass IndexedPriorityQueue(Generic[K, V]): def __init__(self) -> None: self._heap: list[HeapEntry[K, V]] = [] self._key_to_index: dict[K, int] = {} def __len__(self) -> int: return len(self._heap) def __contains__(self, key: K) -> bool: return key in self._key_to_index def insert(self, key: K, priority: V) -> None: if key in self._key_to_index: raise KeyError(f"Key {key} already exists") index = len(self._heap) self._heap.append(HeapEntry(key, priority)) self._key_to_index[key] = index self._bubble_up(index) def extract_min(self) -> tuple[K, V]: if not self._heap: raise IndexError("Priority queue is empty") min_entry = self._heap[0] self._remove_at(0) return min_entry.key, min_entry.priority def peek_min(self) -> tuple[K, V]: if not self._heap: raise IndexError("Priority queue is empty") return self._heap[0].key, self._heap[0].priority def update_priority(self, key: K, new_priority: V) -> None: if key not in self._key_to_index: raise KeyError(f"Key {key} not found") index = self._key_to_index[key] old_priority = self._heap[index].priority self._heap[index].priority = new_priority if new_priority < old_priority: self._bubble_up(index) else: self._bubble_down(index) def _bubble_up(self, index: int) -> None: while index > 0: parent = (index - 1) // 2 if self._heap[index].priority >= self._heap[parent].priority: break self._swap(index, parent) index = parent def _bubble_down(self, index: int) -> None: size = len(self._heap) while True: smallest = index left, right = 2 * index + 1, 2 * index + 2 if left < size and self._heap[left].priority < self._heap[smallest].priority: smallest = left if right < size and self._heap[right].priority < self._heap[smallest].priority: smallest = right if smallest == index: break self._swap(index, smallest) index = smallest def _swap(self, i: int, j: int) -> None: self._key_to_index[self._heap[i].key] = j self._key_to_index[self._heap[j].key] = i self._heap[i], self._heap[j] = self._heap[j], self._heap[i] def _remove_at(self, index: int) -> None: del self._key_to_index[self._heap[index].key] if index == len(self._heap) - 1: self._heap.pop() else: last = self._heap.pop() self._heap[index] = last self._key_to_index[last.key] = index self._bubble_up(index) self._bubble_down(index)
Solution 2: Using heapq with lazy deletion
Lazy deletion approach using Python’s heapq. Updates mark old entries as invalid and push new ones. Invalid entries are skipped during extraction. Simpler but uses more memory.
"""Indexed Priority Queue - Solution 2: Using heapq with Lazy DeletionSimpler implementation using Python's heapq. Deleted/updated entries aremarked invalid and skipped during extraction."""from dataclasses import dataclass, fieldfrom typing import Generic, TypeVar, Optionalimport heapqK = TypeVar('K')V = TypeVar('V')REMOVED = object()@dataclass(order=True)class Entry(Generic[K, V]): priority: V key: K = field(compare=False) valid: bool = field(default=True, compare=False)class IndexedPriorityQueue(Generic[K, V]): def __init__(self) -> None: self._heap: list[Entry[K, V]] = [] self._entries: dict[K, Entry[K, V]] = {} def __len__(self) -> int: return len(self._entries) def __contains__(self, key: K) -> bool: return key in self._entries def insert(self, key: K, priority: V) -> None: if key in self._entries: raise KeyError(f"Key {key} already exists") entry = Entry(priority=priority, key=key) self._entries[key] = entry heapq.heappush(self._heap, entry) def extract_min(self) -> tuple[K, V]: while self._heap: entry = heapq.heappop(self._heap) if entry.valid: del self._entries[entry.key] return entry.key, entry.priority raise IndexError("Priority queue is empty") def peek_min(self) -> tuple[K, V]: while self._heap and not self._heap[0].valid: heapq.heappop(self._heap) if not self._heap: raise IndexError("Priority queue is empty") entry = self._heap[0] return entry.key, entry.priority def update_priority(self, key: K, new_priority: V) -> None: if key not in self._entries: raise KeyError(f"Key {key} not found") old_entry = self._entries[key] old_entry.valid = False new_entry = Entry(priority=new_priority, key=key) self._entries[key] = new_entry heapq.heappush(self._heap, new_entry) def remove(self, key: K) -> V: if key not in self._entries: raise KeyError(f"Key {key} not found") entry = self._entries.pop(key) entry.valid = False return entry.priority def get_priority(self, key: K) -> V: if key not in self._entries: raise KeyError(f"Key {key} not found") return self._entries[key].priority
Solution 3: Dict-based with linear scan
Simple O(n) implementation using a dictionary. Linear scan for min. Useful as a baseline or when n is small. Easy to understand and extend.
"""Indexed Priority Queue - Solution 3: Dict-based with Linear ScanSimple O(n) implementation - useful as a baseline or when n is small.Easy to understand and debug, minimal code."""from dataclasses import dataclassfrom typing import Generic, TypeVar, OptionalK = TypeVar('K')V = TypeVar('V')@dataclassclass IndexedPriorityQueue(Generic[K, V]): _data: dict[K, V] = None def __post_init__(self) -> None: if self._data is None: self._data = {} def __len__(self) -> int: return len(self._data) def __contains__(self, key: K) -> bool: return key in self._data def insert(self, key: K, priority: V) -> None: if key in self._data: raise KeyError(f"Key {key} already exists") self._data[key] = priority def extract_min(self) -> tuple[K, V]: if not self._data: raise IndexError("Priority queue is empty") min_key = min(self._data, key=lambda k: self._data[k]) priority = self._data.pop(min_key) return min_key, priority def extract_max(self) -> tuple[K, V]: if not self._data: raise IndexError("Priority queue is empty") max_key = max(self._data, key=lambda k: self._data[k]) priority = self._data.pop(max_key) return max_key, priority def peek_min(self) -> tuple[K, V]: if not self._data: raise IndexError("Priority queue is empty") min_key = min(self._data, key=lambda k: self._data[k]) return min_key, self._data[min_key] def peek_max(self) -> tuple[K, V]: if not self._data: raise IndexError("Priority queue is empty") max_key = max(self._data, key=lambda k: self._data[k]) return max_key, self._data[max_key] def update_priority(self, key: K, new_priority: V) -> None: if key not in self._data: raise KeyError(f"Key {key} not found") self._data[key] = new_priority def decrease_key(self, key: K, new_priority: V) -> None: if key not in self._data: raise KeyError(f"Key {key} not found") if new_priority > self._data[key]: raise ValueError("New priority must be less than current") self._data[key] = new_priority def increase_key(self, key: K, new_priority: V) -> None: if key not in self._data: raise KeyError(f"Key {key} not found") if new_priority < self._data[key]: raise ValueError("New priority must be greater than current") self._data[key] = new_priority def remove(self, key: K) -> V: if key not in self._data: raise KeyError(f"Key {key} not found") return self._data.pop(key) def get_priority(self, key: K) -> V: if key not in self._data: raise KeyError(f"Key {key} not found") return self._data[key]
Question 32 - Game of Life
Difficulty: 2 / 10
Approximate lines of code: 100 LoC
Tags: game/simulation
Description
Conway’s Game of Life is a cellular automaton on a 2D grid where each cell is either alive or dead. Cells evolve simultaneously according to four rules based on neighbor count (the 8 adjacent cells):
Underpopulation: A live cell with < 2 neighbors dies
Survival: A live cell with 2-3 neighbors survives
Overpopulation: A live cell with > 3 neighbors dies
Reproduction: A dead cell with exactly 3 neighbors becomes alive
The critical implementation detail is that all cells update simultaneously - you cannot modify the grid in-place while iterating. You need either a second buffer or a way to encode both old and new states.
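The solutions below all use a second buffer (or build a new set), but the "encode both states" alternative mentioned above looks roughly like this sketch: bit 0 of each cell holds the current state and bit 1 stages the next state, so neighbor counts stay correct while you write. Dead-boundary handling is assumed here.
def step_in_place(grid: list[list[int]]) -> None:
    """Sketch: one Game of Life step without a second grid, using 2 bits per cell."""
    if not grid:
        return
    rows, cols = len(grid), len(grid[0])
    for r in range(rows):
        for c in range(cols):
            live = 0
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    if dr == 0 and dc == 0:
                        continue
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < rows and 0 <= nc < cols:
                        live += grid[nr][nc] & 1          # read only the "current" bit
            if grid[r][c] & 1:
                if live in (2, 3):
                    grid[r][c] |= 2                       # survives: stage the "next" bit
            elif live == 3:
                grid[r][c] |= 2                           # reproduction
    for r in range(rows):
        for c in range(cols):
            grid[r][c] >>= 1                              # promote the staged bit to current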
Part A: Basic Simulation
Problem: Part A
Implement a fixed-size grid with set_alive(cells) to initialize and step() to advance one generation. All cells must update simultaneously based on the previous state.
game = GameOfLife(rows=5, cols=5)# Initialize a "blinker" - oscillates between horizontal and verticalgame.set_alive({(2, 1), (2, 2), (2, 3)})# Internal state:# .....# .....# .###. <- Three horizontal cells# .....# .....game.step()# After one step:# .....# ..#..# ..#.. <- Three vertical cells# ..#..# .....game.step()# Returns to original horizontal patternassert game.get_live_cells() == {(2, 1), (2, 2), (2, 3)}
Part B: Boundary Handling
Problem: Part B
Handle grid edges. Two common approaches:
Toroidal (wrap-around): Cells on opposite edges are neighbors
Dead boundary: Cells outside the grid are always dead
# Wrap-around example on 3x3 gridgame = GameOfLife(rows=3, cols=3, wrap=True)game.set_alive({(0, 0), (0, 1), (0, 2)}) # Top row# With wrap-around, top row neighbors include bottom rowgame.step()assert (2, 1) in game.get_live_cells() # Cell born at bottom# Dead boundary examplegame2 = GameOfLife(rows=3, cols=3, wrap=False)game2.set_alive({(0, 0), (0, 1), (0, 2)})game2.step()# (2, 1) not born because top row has no neighbors wrapping
Part C: Known Patterns and Infinite Grid
Problem: Part C
Implement support for known patterns and discuss infinite grid representation using sparse sets.
# Still life - "Block" (stable, doesn't change)block = {(1, 1), (1, 2), (2, 1), (2, 2)}game.set_alive(block)game.step()assert game.get_live_cells() == block# Glider - moves diagonally over timeglider = {(0, 1), (1, 2), (2, 0), (2, 1), (2, 2)}game.set_alive(glider)for _ in range(4): game.step()# After 4 generations, glider has moved (1, 1) diagonallyexpected = {(r + 1, c + 1) for r, c in glider}assert game.get_live_cells() == expected# Infinite grid: Use set of (row, col) tuples instead of 2D array# Only store live cells, check neighbors on demand
Interview comments
Edge cases to probe:
What happens to a single isolated cell? (Dies from underpopulation)
What if you forget to use a buffer? (Earlier updates affect later cells incorrectly)
How do you count neighbors efficiently? (8 directions, handle bounds)
What’s the complexity of step() for a dense grid vs. sparse grid?
Common mistakes:
In-place updates instead of using a buffer (the #1 bug)
Including the center cell in the neighbor count
Off-by-one errors in wrap-around modulo arithmetic
Checking only live cells for reproduction (must also check dead neighbors of live cells)
Not handling empty grid (should remain empty)
Code solutions
Solution 1 uses a fixed-size 2D list with wrap-around (toroidal) topology and creates a new grid each step to avoid in-place mutation. Solution 2 uses a sparse set representation for infinite grids, storing only live cells and using Counter for neighbor counting. Solution 3 uses NumPy arrays with vectorized operations and np.roll for neighbor counting, supporting both wrap-around and dead boundary modes. The key differences are grid representation (dense array vs sparse set) and boundary handling (toroidal vs infinite vs configurable). Core techniques: double buffering, modular arithmetic for wrap-around, sparse set with neighbor counting, vectorized array operations.
Solution 1: Fixed Grid with Wrap-around
Fixed grid using 2D list with wrap-around (toroidal topology). Uses modular arithmetic for neighbor access. Creates new grid each step to avoid in-place mutation.
"""Conway's Game of Life - Solution 1: Fixed Grid with Wrap-aroundUses a 2D list with modular arithmetic for toroidal topology."""from dataclasses import dataclass, fieldfrom typing import List, Set, Tuple@dataclassclass GameOfLife: """Fixed-size grid with wrap-around edges (toroidal topology).""" rows: int cols: int grid: List[List[bool]] = field(default_factory=list) def __post_init__(self) -> None: if not self.grid: self.grid = [[False] * self.cols for _ in range(self.rows)] def set_alive(self, cells: Set[Tuple[int, int]]) -> None: """Initialize live cells.""" for r, c in cells: self.grid[r % self.rows][c % self.cols] = True def count_neighbors(self, row: int, col: int) -> int: """Count live neighbors with wrap-around.""" count = 0 for dr in (-1, 0, 1): for dc in (-1, 0, 1): if dr == 0 and dc == 0: continue nr = (row + dr) % self.rows nc = (col + dc) % self.cols if self.grid[nr][nc]: count += 1 return count def step(self) -> None: """Advance simulation by one generation.""" new_grid = [[False] * self.cols for _ in range(self.rows)] for r in range(self.rows): for c in range(self.cols): neighbors = self.count_neighbors(r, c) alive = self.grid[r][c] # Rules: survive with 2-3 neighbors, born with exactly 3 if alive and neighbors in (2, 3): new_grid[r][c] = True elif not alive and neighbors == 3: new_grid[r][c] = True self.grid = new_grid def get_live_cells(self) -> Set[Tuple[int, int]]: """Return set of live cell coordinates.""" return {(r, c) for r in range(self.rows) for c in range(self.cols) if self.grid[r][c]} def display(self) -> str: """Return string representation of grid.""" return "\n".join("".join("#" if cell else "." for cell in row) for row in self.grid)if __name__ == "__main__": # Test 1: Blinker oscillator (period 2) game = GameOfLife(5, 5) game.set_alive({(2, 1), (2, 2), (2, 3)}) # Horizontal line initial = game.get_live_cells() game.step() after_one = game.get_live_cells() game.step() after_two = game.get_live_cells() assert after_one == {(1, 2), (2, 2), (3, 2)}, f"Blinker step 1 failed: {after_one}" assert after_two == initial, f"Blinker should return to initial: {after_two}" print("Blinker test passed!") # Test 2: Block still life (stable) game2 = GameOfLife(4, 4) block = {(1, 1), (1, 2), (2, 1), (2, 2)} game2.set_alive(block) game2.step() assert game2.get_live_cells() == block, "Block should be stable" print("Block test passed!") # Test 3: Wrap-around game3 = GameOfLife(3, 3) game3.set_alive({(0, 0), (0, 1), (0, 2)}) # Top row game3.step() assert (2, 1) in game3.get_live_cells(), "Wrap-around should create cell at bottom" print("Wrap-around test passed!") print("All tests passed!")
Solution 2: Infinite Grid with Sparse Representation
Infinite grid using sparse set representation. Only stores live cells. Uses Counter to efficiently track neighbor counts. Memory-efficient for sparse patterns like gliders.
"""Conway's Game of Life - Solution 2: Infinite Grid with Sparse RepresentationUses a set of live cell coordinates - efficient for sparse patterns."""from dataclasses import dataclass, fieldfrom typing import Set, Tuple, Dictfrom collections import Counter@dataclassclass InfiniteGameOfLife: """Infinite grid using sparse set representation.""" live_cells: Set[Tuple[int, int]] = field(default_factory=set) def set_alive(self, cells: Set[Tuple[int, int]]) -> None: """Initialize live cells.""" self.live_cells = set(cells) def _get_neighbors(self, row: int, col: int) -> Set[Tuple[int, int]]: """Return all 8 neighbors of a cell.""" return { (row + dr, col + dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if not (dr == 0 and dc == 0) } def step(self) -> None: """Advance simulation by one generation using neighbor counting.""" # Count how many live neighbors each cell has neighbor_counts: Dict[Tuple[int, int], int] = Counter() for cell in self.live_cells: for neighbor in self._get_neighbors(*cell): neighbor_counts[neighbor] += 1 new_live: Set[Tuple[int, int]] = set() # Check all cells that could potentially be alive candidates = self.live_cells | set(neighbor_counts.keys()) for cell in candidates: count = neighbor_counts.get(cell, 0) alive = cell in self.live_cells if alive and count in (2, 3): new_live.add(cell) elif not alive and count == 3: new_live.add(cell) self.live_cells = new_live def get_live_cells(self) -> Set[Tuple[int, int]]: """Return set of live cell coordinates.""" return set(self.live_cells) def get_bounds(self) -> Tuple[int, int, int, int]: """Return (min_row, max_row, min_col, max_col) or (0,0,0,0) if empty.""" if not self.live_cells: return (0, 0, 0, 0) rows = [r for r, c in self.live_cells] cols = [c for r, c in self.live_cells] return (min(rows), max(rows), min(cols), max(cols)) def display(self, padding: int = 1) -> str: """Return string representation of bounding box with padding.""" if not self.live_cells: return "." min_r, max_r, min_c, max_c = self.get_bounds() lines = [] for r in range(min_r - padding, max_r + padding + 1): row = "" for c in range(min_c - padding, max_c + padding + 1): row += "#" if (r, c) in self.live_cells else "." lines.append(row) return "\n".join(lines)if __name__ == "__main__": # Test 1: Glider moves diagonally game = InfiniteGameOfLife() glider = {(0, 1), (1, 2), (2, 0), (2, 1), (2, 2)} game.set_alive(glider) for _ in range(4): game.step() # After 4 steps, glider moves down-right by (1, 1) expected = {(r + 1, c + 1) for r, c in glider} assert game.get_live_cells() == expected, f"Glider failed: {game.get_live_cells()}" print("Glider test passed!") # Test 2: Blinker game2 = InfiniteGameOfLife() game2.set_alive({(0, -1), (0, 0), (0, 1)}) game2.step() assert game2.get_live_cells() == {(-1, 0), (0, 0), (1, 0)}, "Blinker step 1 failed" game2.step() assert game2.get_live_cells() == {(0, -1), (0, 0), (0, 1)}, "Blinker step 2 failed" print("Blinker test passed!") # Test 3: Empty grid stays empty game3 = InfiniteGameOfLife() game3.step() assert game3.get_live_cells() == set(), "Empty grid should stay empty" print("Empty grid test passed!") # Test 4: Single cell dies game4 = InfiniteGameOfLife() game4.set_alive({(0, 0)}) game4.step() assert game4.get_live_cells() == set(), "Single cell should die" print("Single cell test passed!") print("All tests passed!")
Solution 3: NumPy-based for Performance
NumPy-based implementation for performance on large grids. Uses np.roll for neighbor counting and vectorized boolean operations. Supports both wrap-around and dead boundary modes.
"""Conway's Game of Life - Solution 3: NumPy-based for PerformanceUses 2D numpy array with manual neighbor counting via array slicing."""from dataclasses import dataclassfrom typing import Set, Tupleimport numpy as np@dataclassclass NumpyGameOfLife: """Fixed-size grid using numpy for efficient computation.""" rows: int cols: int wrap: bool = True def __post_init__(self) -> None: self.grid: np.ndarray = np.zeros((self.rows, self.cols), dtype=np.int8) def set_alive(self, cells: Set[Tuple[int, int]]) -> None: """Initialize live cells.""" self.grid.fill(0) for r, c in cells: self.grid[r % self.rows, c % self.cols] = 1 def _count_neighbors(self) -> np.ndarray: """Count neighbors for all cells using array rolling.""" counts = np.zeros_like(self.grid, dtype=np.int8) for dr in (-1, 0, 1): for dc in (-1, 0, 1): if dr == 0 and dc == 0: continue shifted = np.roll(np.roll(self.grid, dr, axis=0), dc, axis=1) if not self.wrap: # Zero out wrapped edges for non-wrapping mode if dr == -1: shifted[-1, :] = 0 elif dr == 1: shifted[0, :] = 0 if dc == -1: shifted[:, -1] = 0 elif dc == 1: shifted[:, 0] = 0 counts += shifted return counts def step(self) -> None: """Advance simulation using vectorized operations.""" neighbor_count = self._count_neighbors() # Apply rules vectorized: # Survive: alive and 2-3 neighbors # Born: dead and exactly 3 neighbors survive = (self.grid == 1) & ((neighbor_count == 2) | (neighbor_count == 3)) born = (self.grid == 0) & (neighbor_count == 3) self.grid = (survive | born).astype(np.int8) def get_live_cells(self) -> Set[Tuple[int, int]]: """Return set of live cell coordinates.""" rows, cols = np.where(self.grid == 1) return set(zip(rows.tolist(), cols.tolist())) def count_alive(self) -> int: """Return total number of live cells.""" return int(np.sum(self.grid)) def display(self) -> str: """Return string representation of grid.""" chars = np.where(self.grid == 1, "#", ".") return "\n".join("".join(row) for row in chars)if __name__ == "__main__": # Test 1: Blinker oscillator game = NumpyGameOfLife(5, 5) game.set_alive({(2, 1), (2, 2), (2, 3)}) initial = game.get_live_cells() game.step() after_one = game.get_live_cells() game.step() after_two = game.get_live_cells() assert after_one == {(1, 2), (2, 2), (3, 2)}, f"Blinker step 1 failed: {after_one}" assert after_two == initial, f"Blinker should oscillate: {after_two}" print("Blinker test passed!") # Test 2: Block still life game2 = NumpyGameOfLife(4, 4) block = {(1, 1), (1, 2), (2, 1), (2, 2)} game2.set_alive(block) game2.step() assert game2.get_live_cells() == block, "Block should be stable" print("Block test passed!") # Test 3: Glider (4 generations) game3 = NumpyGameOfLife(10, 10) glider = {(1, 2), (2, 3), (3, 1), (3, 2), (3, 3)} game3.set_alive(glider) for _ in range(4): game3.step() expected = {(r + 1, c + 1) for r, c in glider} assert game3.get_live_cells() == expected, f"Glider failed: {game3.get_live_cells()}" print("Glider test passed!") # Test 4: Performance - large grid big_game = NumpyGameOfLife(100, 100) import random random.seed(42) cells = {(random.randint(0, 99), random.randint(0, 99)) for _ in range(1000)} big_game.set_alive(cells) for _ in range(10): big_game.step() print(f"Large grid test passed! Final population: {big_game.count_alive()}") print("All tests passed!")
Question 33 - Barrier Sync
Difficulty: 7 / 10
Approximate lines of code: 90 LoC
Tags: concurrency
Description
Concurrency Cheat Sheet
This problem involves concurrency. Here is a cheat sheet of relevant Python syntax. You may not need all of these; the list is neither exhaustive nor a hint about which primitives the problem requires.
# threading.Threadt = threading.Thread(target=fn, args=(arg1,))t.start()t.join() # wait for completion# threading.Locklock = threading.Lock()lock.acquire()lock.release()with lock: # auto acquire/release pass# threading.Conditioncond = threading.Condition(lock) # or Condition() for internal lockwith cond: cond.wait() # release lock, block until notify, re-acquire cond.wait(timeout) # returns False on timeout cond.notify() # wake one waiter cond.notify_all() # wake all waiters# threading.Semaphoresem = threading.Semaphore(n) # n permitssem.acquire() # block if count=0sem.release() # increment countsem.acquire(timeout=1.0) # returns False on timeout# threading.Eventevent = threading.Event()event.set() # set flag, wake all waitersevent.clear() # reset flagevent.wait() # block until flag is setevent.wait(timeout) # returns False on timeoutevent.is_set() # check flag# threading.Barrier (built-in reference)barrier = threading.Barrier(n)barrier.wait() # block until n threads arrivebarrier.wait(timeout) # raises BrokenBarrierError on timeoutbarrier.reset() # reset to initial statebarrier.abort() # break barrier
A barrier is a synchronization primitive that blocks a group of threads until all of them have reached a certain point. Once all N threads call wait(), they are all released simultaneously. This is commonly used in parallel algorithms where each phase must complete before the next begins (e.g., parallel matrix operations, simulation steps).
The key implementation challenge is making the barrier reusable - after all threads pass through, the barrier should reset for the next round. A naive implementation fails when a fast thread re-enters the barrier before a slow thread has exited, causing the count to be wrong. The solution is “generation counting” - each barrier cycle has a generation number, and threads check their generation before proceeding.
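A minimal sketch of the generation-counting core, with no timeout, callback, or broken-barrier handling (those appear in the full solutions below):
import threading

class MiniBarrier:
    """Sketch only: a reusable barrier via generation counting."""
    def __init__(self, parties: int) -> None:
        self.parties = parties
        self.count = 0
        self.generation = 0
        self.cond = threading.Condition()

    def wait(self) -> None:
        with self.cond:
            gen = self.generation           # remember which cycle we arrived in
            self.count += 1
            if self.count == self.parties:  # last arrival: reset and start the next cycle
                self.count = 0
                self.generation += 1
                self.cond.notify_all()
                return
            while gen == self.generation:   # sleep until *our* cycle completes
                self.cond.wait()
A fast thread that re-enters immediately reads the new generation number, so it can't be confused with stragglers still waking up from the previous cycle.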
Part A: Basic Barrier
Problem: Part A
Implement a barrier where wait() blocks until all N parties have arrived.
barrier = Barrier(parties=3)# Thread 1, 2, 3 all call wait()def worker(thread_id): print(f"Thread {thread_id} before barrier") barrier.wait() # Blocks until all 3 arrive print(f"Thread {thread_id} after barrier")# Output (before lines all complete, then after lines):# Thread 0 before barrier# Thread 1 before barrier# Thread 2 before barrier# Thread 0 after barrier # These three can be in any order# Thread 1 after barrier# Thread 2 after barrier
Part B: Reusable Barrier
Problem: Part B
Make the barrier reusable for multiple rounds. Handle the “fast thread” race condition using generation counting.
barrier = Barrier(parties=2)def worker(): for round in range(3): # Do work... barrier.wait() # Must work correctly for all 3 rounds# Without generation counting, this breaks:# Round 1: Thread A arrives, Thread B arrives, both released# Round 2: Thread A (fast) re-enters while Thread B still exiting# count becomes 1, Thread A waits, Thread B exits# Thread B re-enters, count=2, releases both# But Thread A was waiting in round 2 with round 1's state!
Part C: Timeout and Callback
Problem: Part C
Add timeout support (returns False on timeout without corrupting barrier state) and a callback that executes exactly once when all threads arrive.
# Timeoutbarrier = Barrier(parties=2)result = barrier.wait(timeout=1.0) # Returns False after 1 second# Barrier is now "broken" - subsequent waits raise BrokenBarrierError# Callback (runs once, by the last arriving thread)def on_all_arrived(): print("All threads synchronized!")barrier = Barrier(parties=3, callback=on_all_arrived)# When third thread arrives, callback runs before any thread proceeds
Interview comments
Edge cases to probe:
What happens if a thread dies while waiting? (Barrier hangs or breaks)
Why is generation counting necessary? (Fast thread reentry race)
What if callback raises an exception?
Can you implement without condition variables (e.g., with semaphores)?
Common mistakes:
Timeout corrupting the count (must be handled cleanly)
Holding the lock while waiting (blocks other threads from arriving)
Using notify() instead of notify_all() (only one thread wakes)
Code solutions
Solution 1 uses threading.Condition with generation counting - waiting threads check their generation to know when to proceed. Solution 2 uses a two-phase turnstile pattern with semaphores, ensuring clean separation between arrival and departure phases. Solution 3 uses threading.Event where each generation gets its own Event object that gets set and replaced. The key difference is the synchronization primitive used. Core techniques: generation counting, condition variables, semaphores, event signaling.
Solution 1: Condition variable approach
Uses threading.Condition with generation counting. Last thread to arrive resets count, increments generation, and calls notify_all(). Waiting threads check their generation to know if they should proceed.
"""Barrier Sync Solution 1: Condition Variable ApproachA reusable thread barrier using threading.Condition for synchronization.Supports timeout, callback, and graceful handling of thread failures."""import threadingfrom dataclasses import dataclass, fieldfrom typing import Callable, Optional@dataclassclass Barrier: """A reusable thread barrier with timeout and callback support.""" parties: int callback: Optional[Callable[[], None]] = None _lock: threading.Lock = field(default_factory=threading.Lock, repr=False) _condition: threading.Condition = field(init=False, repr=False) _count: int = field(default=0, init=False) _generation: int = field(default=0, init=False) _broken: bool = field(default=False, init=False) def __post_init__(self) -> None: if self.parties <= 0: raise ValueError("parties must be positive") self._condition = threading.Condition(self._lock) def wait(self, timeout: Optional[float] = None) -> bool: """ Wait for all parties to reach the barrier. Returns True if barrier was successfully crossed, False on timeout. Raises BrokenBarrierError if barrier is broken. """ with self._condition: if self._broken: raise BrokenBarrierError("Barrier is broken") gen = self._generation self._count += 1 if self._count == self.parties: # Last thread to arrive: run callback and release everyone if self.callback: self.callback() self._count = 0 self._generation += 1 self._condition.notify_all() return True # Wait for other threads while gen == self._generation and not self._broken: if not self._condition.wait(timeout): # Timeout occurred - break the barrier self._broken = True self._condition.notify_all() return False if self._broken: raise BrokenBarrierError("Barrier is broken") return True def reset(self) -> None: """Reset the barrier to initial state.""" with self._condition: self._count = 0 self._generation += 1 self._broken = False self._condition.notify_all() def abort(self) -> None: """Break the barrier, causing all waiting threads to raise.""" with self._condition: self._broken = True self._condition.notify_all()class BrokenBarrierError(Exception): """Raised when a barrier is broken.""" pass
Solution 2: Semaphore-based approach
Two-phase turnstile pattern using semaphores. Phase 1: all threads arrive and wait on turnstile1. Phase 2: all threads depart through turnstile2. Ensures clean separation between rounds.
"""Barrier Sync Solution 2: Semaphore-Based ApproachA reusable thread barrier using two semaphores (turnstile pattern).Supports timeout, callback, and graceful handling of thread failures."""import threadingfrom dataclasses import dataclass, fieldfrom typing import Callable, Optional@dataclassclass Barrier: """A reusable barrier using the two-phase turnstile pattern.""" parties: int callback: Optional[Callable[[], None]] = None _mutex: threading.Lock = field(default_factory=threading.Lock, repr=False) _turnstile1: threading.Semaphore = field(init=False, repr=False) _turnstile2: threading.Semaphore = field(init=False, repr=False) _count: int = field(default=0, init=False) _broken: bool = field(default=False, init=False) def __post_init__(self) -> None: if self.parties <= 0: raise ValueError("parties must be positive") self._turnstile1 = threading.Semaphore(0) self._turnstile2 = threading.Semaphore(0) def wait(self, timeout: Optional[float] = None) -> bool: """Wait for all parties. Returns False on timeout.""" # Phase 1: Wait for all threads to arrive with self._mutex: if self._broken: raise BrokenBarrierError("Barrier is broken") self._count += 1 if self._count == self.parties: if self.callback: self.callback() for _ in range(self.parties): self._turnstile1.release() if not self._turnstile1.acquire(timeout=timeout): self._break_barrier() return False # Phase 2: Wait for all threads to pass turnstile1 with self._mutex: if self._broken: raise BrokenBarrierError("Barrier is broken") self._count -= 1 if self._count == 0: for _ in range(self.parties): self._turnstile2.release() if not self._turnstile2.acquire(timeout=timeout): self._break_barrier() return False if self._broken: raise BrokenBarrierError("Barrier is broken") return True def _break_barrier(self) -> None: """Mark barrier as broken and release waiting threads.""" with self._mutex: self._broken = True # Release all potentially waiting threads for _ in range(self.parties): self._turnstile1.release() self._turnstile2.release() def reset(self) -> None: """Reset barrier to initial state.""" with self._mutex: self._count = 0 self._broken = False self._turnstile1 = threading.Semaphore(0) self._turnstile2 = threading.Semaphore(0) def abort(self) -> None: """Break the barrier.""" self._break_barrier()class BrokenBarrierError(Exception): """Raised when a barrier is broken.""" pass
Solution 3: Event-based approach
Event-based approach. Each generation gets its own Event object. Last thread sets the event and creates a new one for the next round. Clean and minimal.
"""Barrier Sync Solution 3: Event-Based ApproachA reusable thread barrier using threading.Event for signaling.Uses generation counting for reusability. Clean, minimal implementation."""import threadingfrom dataclasses import dataclass, fieldfrom typing import Callable, Optional@dataclassclass Barrier: """A reusable barrier using Event signaling with generation tracking.""" parties: int callback: Optional[Callable[[], None]] = None _lock: threading.Lock = field(default_factory=threading.Lock, repr=False) _event: threading.Event = field(default_factory=threading.Event, repr=False) _count: int = field(default=0, init=False) _generation: int = field(default=0, init=False) _broken: bool = field(default=False, init=False) def __post_init__(self) -> None: if self.parties <= 0: raise ValueError("parties must be positive") def wait(self, timeout: Optional[float] = None) -> bool: """ Block until all parties arrive. Returns True on success, False on timeout. Raises BrokenBarrierError if barrier breaks. """ with self._lock: if self._broken: raise BrokenBarrierError("Barrier is broken") my_gen = self._generation self._count += 1 if self._count == self.parties: # We're the last one - run callback, reset and signal everyone if self.callback: self.callback() self._count = 0 self._generation += 1 self._event.set() self._event = threading.Event() # New event for next round return True # Store current event before releasing lock current_event = self._event # Wait outside the lock signaled = current_event.wait(timeout) if not signaled: # Timeout - break the barrier with self._lock: if self._generation == my_gen: # Still our generation self._broken = True self._event.set() # Wake everyone return False with self._lock: if self._broken and self._generation == my_gen: raise BrokenBarrierError("Barrier is broken") return True def reset(self) -> None: """Reset barrier to clean state.""" with self._lock: self._count = 0 self._generation += 1 self._broken = False self._event.set() self._event = threading.Event() def abort(self) -> None: """Break the barrier, waking all waiters.""" with self._lock: self._broken = True self._event.set() @property def n_waiting(self) -> int: """Number of threads currently waiting.""" with self._lock: return self._countclass BrokenBarrierError(Exception): """Raised when a barrier is broken by timeout or abort.""" pass
Question 34 - Minesweeper
Difficulty: 3 / 10
Approximate lines of code: 80 LoC
Tags: game/simulation
Description
Minesweeper is a classic puzzle game where a grid contains hidden mines. Each non-mine cell displays the count of adjacent mines (0-8). The player reveals cells; revealing a mine loses the game, revealing all non-mine cells wins. The key data structure is a 2D grid of cells, each with state (hidden/revealed/flagged), whether it’s a mine, and its adjacent mine count.
The main algorithmic challenge is the cascade reveal: when revealing a cell with 0 adjacent mines, automatically reveal all neighbors recursively. This is a flood-fill algorithm (BFS or DFS) that stops when it hits cells with non-zero counts.
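Stripped of the game bookkeeping, the cascade on its own looks like this sketch (the names counts, mines, revealed, and the neighbors callable are illustrative, not the ones used in the solutions below; hitting a mine is assumed to be handled before this is called):
from collections import deque

def cascade_reveal(start, counts, mines, revealed, neighbors):
    """Sketch: flood-fill reveal. Zero-count cells keep expanding; numbered cells are revealed but stop."""
    queue = deque([start])
    while queue:
        cell = queue.popleft()
        if cell in revealed or cell in mines:
            continue
        revealed.add(cell)
        if counts.get(cell, 0) == 0:
            queue.extend(n for n in neighbors(cell) if n not in revealed)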
Part A: Board Setup with Adjacent Counts
Problem: Part A
Create a board with randomly placed mines and pre-compute adjacent mine counts for all cells. Each cell needs to know how many of its 8 neighbors contain mines.
game = Minesweeper(rows=5, cols=5, num_mines=3)# Internal state after creation:# - 3 cells have is_mine=True (random positions)# - Each non-mine cell has adjacent_mines computed# - All cells start with state=HIDDEN# For example, if mines are at (0,0), (2,2), (4,4):# Adjacent counts would be:# * 1 0 0 0# 1 2 1 1 0# 0 1 * 1 0# 0 1 1 2 1# 0 0 0 1 *
Part B: Reveal with Cascade
Problem: Part B
Implement reveal(row, col) that:
Returns False if a mine is hit
Reveals the cell
If adjacent_mines == 0, recursively reveals all 8 neighbors
Cascade stops at cells with non-zero counts (they’re revealed but don’t cascade further)
game = Minesweeper(rows=5, cols=5, num_mines=3)# Assume mine at (2,2) only, so many cells have count 0result = game.reveal(0, 0) # True (not a mine)# If (0,0) has count 0, cascade reveals neighbors# Cascade continues until hitting cells with count > 0# Internal state: cells with count 0 in connected region all revealed# Cells bordering the mine (1,1), (1,2), (1,3), etc. revealed and show countsresult = game.reveal(2, 2) # False (hit mine, game over)
Part C: Flags and Win Condition
Problem: Part C
Add flag toggling (mark suspected mines) and win detection (all non-mine cells revealed).
game = Minesweeper(rows=3, cols=3, num_mines=1)# Mine at (1,1)game.toggle_flag(1, 1) # Flag the suspected mine# Internal: cell (1,1) state = FLAGGEDgame.toggle_flag(1, 1) # Unflag# Internal: cell (1,1) state = HIDDEN# Reveal all non-mine cellsfor r in range(3): for c in range(3): if (r, c) != (1, 1): game.reveal(r, c)game.check_win() # True (all 8 non-mine cells revealed)
Interview comments
Edge cases to probe:
Revealing an already-revealed cell?
Flagging a revealed cell?
What if num_mines >= rows * cols?
Revealing a flagged cell (should it unflag first or be blocked)?
Common mistakes:
Cascade through numbered cells (should stop at them, just reveal)
Modifying board during iteration (use BFS queue or recursion stack)
Not handling first-click-is-mine (common game variant)
Including the cell itself in neighbor count
Code solutions
Solution 1 uses an object-oriented approach with a Cell dataclass containing state, is_mine, and adjacent_mines, using recursion for cascade reveal. Solution 2 takes a functional approach with separate 2D arrays for state and counts, keeping the data structures lightweight with string literals. Solution 3 uses sets and itertools for a compact implementation, with BFS (deque) for iterative flood-fill instead of recursion. The key difference is the data representation: nested objects vs parallel arrays vs set-based tracking. Core techniques: flood-fill (DFS/BFS), 2D grid traversal, neighbor enumeration.
Solution 1: Object-Oriented with Cell dataclass
An object-oriented approach with a Cell dataclass containing state, is_mine, and adjacent_mines. Uses recursion for cascade reveal.
"""Minesweeper - Solution 1: Object-Oriented with Cell dataclass."""from dataclasses import dataclass, fieldfrom enum import Enumfrom random import samplefrom typing import Iteratorclass CellState(Enum): HIDDEN = "hidden" REVEALED = "revealed" FLAGGED = "flagged"@dataclassclass Cell: is_mine: bool = False state: CellState = CellState.HIDDEN adjacent_mines: int = 0@dataclassclass Minesweeper: rows: int cols: int num_mines: int board: list[list[Cell]] = field(default_factory=list) def __post_init__(self) -> None: self.board = [[Cell() for _ in range(self.cols)] for _ in range(self.rows)] self._place_mines() self._calculate_adjacent() def _place_mines(self) -> None: positions = sample(range(self.rows * self.cols), self.num_mines) for pos in positions: r, c = divmod(pos, self.cols) self.board[r][c].is_mine = True def _neighbors(self, r: int, c: int) -> Iterator[tuple[int, int]]: for dr in (-1, 0, 1): for dc in (-1, 0, 1): if dr == 0 and dc == 0: continue nr, nc = r + dr, c + dc if 0 <= nr < self.rows and 0 <= nc < self.cols: yield nr, nc def _calculate_adjacent(self) -> None: for r in range(self.rows): for c in range(self.cols): if not self.board[r][c].is_mine: self.board[r][c].adjacent_mines = sum( 1 for nr, nc in self._neighbors(r, c) if self.board[nr][nc].is_mine ) def reveal(self, r: int, c: int) -> bool: """Reveal cell. Returns False if mine hit, True otherwise.""" cell = self.board[r][c] if cell.state != CellState.HIDDEN: return True cell.state = CellState.REVEALED if cell.is_mine: return False if cell.adjacent_mines == 0: for nr, nc in self._neighbors(r, c): self.reveal(nr, nc) return True def toggle_flag(self, r: int, c: int) -> None: cell = self.board[r][c] if cell.state == CellState.HIDDEN: cell.state = CellState.FLAGGED elif cell.state == CellState.FLAGGED: cell.state = CellState.HIDDEN def check_win(self) -> bool: for row in self.board: for cell in row: if not cell.is_mine and cell.state != CellState.REVEALED: return False return True
Solution 2: Functional approach with 2D arrays
A functional approach using 2D arrays for state and counts, with separate functions for game logic. Uses string literals for state to keep it lightweight.
"""Minesweeper - Solution 2: Functional approach with 2D arrays."""from dataclasses import dataclassfrom random import samplefrom typing import LiteralState = Literal["H", "R", "F"] # Hidden, Revealed, Flagged@dataclassclass Board: rows: int cols: int mines: set[tuple[int, int]] states: list[list[State]] counts: list[list[int]] @classmethod def create(cls, rows: int, cols: int, num_mines: int) -> "Board": positions = sample([(r, c) for r in range(rows) for c in range(cols)], num_mines) mines = set(positions) states: list[list[State]] = [["H"] * cols for _ in range(rows)] counts = [[0] * cols for _ in range(rows)] for r in range(rows): for c in range(cols): if (r, c) not in mines: counts[r][c] = sum( 1 for nr, nc in neighbors(r, c, rows, cols) if (nr, nc) in mines ) return cls(rows, cols, mines, states, counts)def neighbors(r: int, c: int, rows: int, cols: int) -> list[tuple[int, int]]: result = [] for dr in (-1, 0, 1): for dc in (-1, 0, 1): if dr == 0 and dc == 0: continue nr, nc = r + dr, c + dc if 0 <= nr < rows and 0 <= nc < cols: result.append((nr, nc)) return resultdef reveal(board: Board, r: int, c: int) -> bool: """Returns False if mine hit.""" if board.states[r][c] != "H": return True board.states[r][c] = "R" if (r, c) in board.mines: return False if board.counts[r][c] == 0: for nr, nc in neighbors(r, c, board.rows, board.cols): reveal(board, nr, nc) return Truedef toggle_flag(board: Board, r: int, c: int) -> None: if board.states[r][c] == "H": board.states[r][c] = "F" elif board.states[r][c] == "F": board.states[r][c] = "H"def check_win(board: Board) -> bool: for r in range(board.rows): for c in range(board.cols): if (r, c) not in board.mines and board.states[r][c] != "R": return False return Truedef check_loss(board: Board) -> bool: return any(board.states[r][c] == "R" for r, c in board.mines)
Solution 3: Using itertools and more compact logic
Uses itertools and sets for a more compact implementation. Cascade reveal uses BFS with a deque for iterative flood-fill instead of recursion.
"""Minesweeper - Solution 3: Using itertools and more compact logic."""from dataclasses import dataclass, fieldfrom itertools import productfrom random import samplefrom collections import deque@dataclassclass Minesweeper: rows: int cols: int mines: frozenset[tuple[int, int]] = field(default_factory=frozenset) revealed: set[tuple[int, int]] = field(default_factory=set) flagged: set[tuple[int, int]] = field(default_factory=set) counts: dict[tuple[int, int], int] = field(default_factory=dict) @classmethod def new_game(cls, rows: int, cols: int, num_mines: int) -> "Minesweeper": all_cells = list(product(range(rows), range(cols))) mines = frozenset(sample(all_cells, num_mines)) game = cls(rows=rows, cols=cols, mines=mines) game._compute_counts() return game def _neighbors(self, r: int, c: int) -> list[tuple[int, int]]: return [ (r + dr, c + dc) for dr, dc in product((-1, 0, 1), repeat=2) if (dr, dc) != (0, 0) and 0 <= r + dr < self.rows and 0 <= c + dc < self.cols ] def _compute_counts(self) -> None: for r, c in product(range(self.rows), range(self.cols)): if (r, c) not in self.mines: self.counts[(r, c)] = sum(1 for n in self._neighbors(r, c) if n in self.mines) def reveal(self, r: int, c: int) -> bool: """BFS-based reveal. Returns False if hit mine.""" if (r, c) in self.revealed or (r, c) in self.flagged: return True if (r, c) in self.mines: self.revealed.add((r, c)) return False queue = deque([(r, c)]) while queue: pos = queue.popleft() if pos in self.revealed: continue self.revealed.add(pos) if self.counts.get(pos, 0) == 0: for neighbor in self._neighbors(*pos): if neighbor not in self.revealed and neighbor not in self.mines: queue.append(neighbor) return True def toggle_flag(self, r: int, c: int) -> None: if (r, c) in self.revealed: return if (r, c) in self.flagged: self.flagged.remove((r, c)) else: self.flagged.add((r, c)) def is_won(self) -> bool: non_mine_cells = self.rows * self.cols - len(self.mines) return len(self.revealed) == non_mine_cells def is_lost(self) -> bool: return bool(self.revealed & self.mines)
Question 35 - Deadlock Detector
Difficulty: 10 / 10
Approximate lines of code: 90 LoC
Tags: concurrency, algorithms
Description
A deadlock occurs when threads are stuck waiting for each other in a cycle: Thread A holds Resource 1 and waits for Resource 2, while Thread B holds Resource 2 and waits for Resource 1. Neither can proceed. This is fundamentally a cycle detection problem in a directed graph called the “wait-for graph.”
The wait-for graph has threads as nodes. An edge from T1 to T2 means “T1 is waiting for a resource held by T2.” Internal state tracks two mappings: resource_owner: dict[resource_id, thread_id] (who holds each resource) and thread_waiting_for: dict[thread_id, resource_id] (what each blocked thread wants). From these, you can construct edges: if thread T waits for resource R, and R is owned by thread U, then there’s an edge T → U. A cycle in this graph means deadlock.
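As a sketch of that construction (over the two mappings named above, not code from the solutions): because each blocked thread here waits on at most one resource, every node has at most one outgoing edge, so a cycle check can simply follow the chain. Solution 2 below shows the full three-color DFS for the general case.
from typing import Optional

def build_wait_for_edges(resource_owner: dict[str, str],
                         thread_waiting_for: dict[str, str]) -> dict[str, str]:
    """Sketch: derive edges T -> U from 'T waits for R' and 'U owns R'."""
    edges: dict[str, str] = {}
    for thread, resource in thread_waiting_for.items():
        owner = resource_owner.get(resource)
        if owner is not None and owner != thread:
            edges[thread] = owner
    return edges

def find_cycle(edges: dict[str, str]) -> Optional[list[str]]:
    """Sketch: follow each wait chain; a chain that re-enters itself is a deadlock."""
    for start in edges:
        seen: list[str] = []
        current = start
        while current in edges and current not in seen:
            seen.append(current)
            current = edges[current]
        if current in seen:
            return seen[seen.index(current):]  # just the threads on the cycle
    return None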
Part A: Resource Tracking
Problem: Part A
Implement acquire(thread_id, resource_id), wait(thread_id, resource_id), and release(thread_id, resource_id). Track who owns what and who is waiting for what.
Part B: Cycle Detection
Problem: Part B
Model wait-for relationships as a directed graph. Implement detect_deadlock() that returns the list of threads in the cycle, or None if no deadlock. Use DFS with three-color marking (WHITE=unvisited, GRAY=in current path, BLACK=done) to detect back-edges.
Part C: Lifecycle and Re-entrant Locks
Problem: Part C
Handle the full lifecycle. When a thread acquires a resource it was waiting for, remove the wait edge. When a thread releases a resource, clear the ownership. Support re-entrant locks, where a thread can re-acquire a resource it already owns (a no-op, with no self-loop in the graph).
Interview comments
Edge cases to probe:
What if a thread tries to acquire a resource it already owns?
What if wait() is called for a resource that’s not held by anyone?
What happens when a thread releases a resource others are waiting for?
How do you handle a thread that’s waiting for multiple resources?
Common mistakes:
Wrong edge direction (should be: waiter → owner, not owner → waiter)
Using simple visited set instead of three-color marking (misses back-edges in complex graphs)
Creating self-loop when thread re-acquires owned resource
Forgetting to clear wait state when thread finally acquires
Not removing edges when resources are released
Returning boolean instead of actual cycle participants
Code solutions
Solution 1 uses simple DFS-based cycle detection following the wait chain from thread to resource owner. Solution 2 builds an explicit graph with adjacency lists and uses three-color DFS (WHITE/GRAY/BLACK) for proper cycle detection in complex graphs. Solution 3 is event-driven, processing a stream of lock events and returning detailed DeadlockInfo with the full wait chain. The key difference is how they represent and traverse the wait-for graph. Core techniques: directed graph modeling, DFS cycle detection, three-color marking.
Solution 1: Simple DFS-based cycle detection
Simple DFS-based cycle detection. Stores resource_owner and thread_waiting_for mappings. Cycle detection follows the wait chain: if thread T waits for resource R owned by U, check if U eventually waits for something T owns.
"""Deadlock Detector - Solution 1: Simple DFS-based cycle detectionBuilds a wait-for graph and uses DFS to detect cycles."""from dataclasses import dataclass, fieldfrom typing import Optional@dataclassclass DeadlockDetector: """Tracks resource ownership and waits, detects deadlock cycles.""" resource_owner: dict[str, str] = field(default_factory=dict) # resource -> thread thread_waiting_for: dict[str, str] = field(default_factory=dict) # thread -> resource def acquire(self, thread_id: str, resource_id: str) -> bool: """Thread acquires a resource. Returns True if successful.""" if resource_id not in self.resource_owner: self.resource_owner[resource_id] = thread_id self.thread_waiting_for.pop(thread_id, None) return True if self.resource_owner[resource_id] == thread_id: return True # Already owns it return False def wait(self, thread_id: str, resource_id: str) -> None: """Thread starts waiting for a resource.""" if resource_id in self.resource_owner and self.resource_owner[resource_id] != thread_id: self.thread_waiting_for[thread_id] = resource_id def release(self, thread_id: str, resource_id: str) -> None: """Thread releases a resource.""" if self.resource_owner.get(resource_id) == thread_id: del self.resource_owner[resource_id] def detect_deadlock(self) -> Optional[list[str]]: """Returns list of threads in deadlock cycle, or None if no deadlock.""" for start_thread in self.thread_waiting_for: cycle = self._find_cycle_from(start_thread) if cycle: return cycle return None def _find_cycle_from(self, start: str) -> Optional[list[str]]: """DFS to find cycle starting from given thread.""" visited = set() path = [] current = start while current and current not in visited: visited.add(current) path.append(current) resource = self.thread_waiting_for.get(current) if not resource: return None owner = self.resource_owner.get(resource) if owner == start: path.append(start) return path current = owner return None
Solution 2: Explicit graph with three-color DFS
Explicit graph representation using adjacency lists. Uses three-color DFS (WHITE/GRAY/BLACK) for proper cycle detection that handles complex graphs with multiple components. Tracks which resources each thread holds for better edge management on release.
"""Deadlock Detector - Solution 2: Graph-based with explicit edge representationUses an adjacency list to represent the wait-for graph more explicitly."""from dataclasses import dataclass, fieldfrom enum import Enumfrom typing import Optionalclass Color(Enum): WHITE = 0 # Unvisited GRAY = 1 # In current DFS path BLACK = 2 # Fully processed@dataclassclass WaitForGraph: """Explicit graph structure for wait-for relationships.""" edges: dict[str, set[str]] = field(default_factory=dict) # thread -> threads it waits for resource_owner: dict[str, str] = field(default_factory=dict) thread_resources: dict[str, set[str]] = field(default_factory=dict) def add_thread(self, thread_id: str) -> None: if thread_id not in self.edges: self.edges[thread_id] = set() self.thread_resources[thread_id] = set() def acquire(self, thread_id: str, resource_id: str) -> bool: self.add_thread(thread_id) if resource_id not in self.resource_owner: self.resource_owner[resource_id] = thread_id self.thread_resources[thread_id].add(resource_id) return True return self.resource_owner[resource_id] == thread_id def request(self, thread_id: str, resource_id: str) -> None: """Thread requests a resource, creating wait-for edge if blocked.""" self.add_thread(thread_id) owner = self.resource_owner.get(resource_id) if owner and owner != thread_id: self.add_thread(owner) self.edges[thread_id].add(owner) def release(self, thread_id: str, resource_id: str) -> None: if self.resource_owner.get(resource_id) == thread_id: del self.resource_owner[resource_id] self.thread_resources[thread_id].discard(resource_id) # Remove edges pointing to this thread for this resource for t in self.edges: self.edges[t].discard(thread_id) def find_cycle(self) -> Optional[list[str]]: """Find a cycle using DFS with coloring.""" color = {t: Color.WHITE for t in self.edges} parent = {} for start in self.edges: if color[start] == Color.WHITE: cycle = self._dfs_cycle(start, color, parent) if cycle: return cycle return None def _dfs_cycle( self, node: str, color: dict[str, Color], parent: dict[str, str] ) -> Optional[list[str]]: color[node] = Color.GRAY for neighbor in self.edges.get(node, set()): if neighbor not in color: continue if color[neighbor] == Color.GRAY: # Found cycle - reconstruct it cycle = [neighbor] curr = node while curr != neighbor: cycle.append(curr) curr = parent.get(curr, neighbor) cycle.append(neighbor) return cycle[::-1] if color[neighbor] == Color.WHITE: parent[neighbor] = node result = self._dfs_cycle(neighbor, color, parent) if result: return result color[node] = Color.BLACK return None
Solution 3: Event-driven with detailed reporting
Event-driven design that processes a stream of lock events (ACQUIRE, RELEASE, WAIT). Returns detailed DeadlockInfo with the full wait chain showing which thread waits for which thread via which resource. Good for integration with monitoring systems.
"""Deadlock Detector - Solution 3: Event-driven with comprehensive reportingProcesses a stream of lock events and provides detailed deadlock information."""from dataclasses import dataclass, fieldfrom enum import Enum, autofrom typing import Optionalclass EventType(Enum): ACQUIRE = auto() RELEASE = auto() WAIT = auto()@dataclassclass LockEvent: event_type: EventType thread_id: str resource_id: str timestamp: int = 0@dataclassclass DeadlockInfo: threads: list[str] resources: list[str] wait_chain: list[tuple[str, str, str]] # (thread, waits_for, resource)@dataclassclass DeadlockDetector: """Event-driven deadlock detector with detailed reporting.""" holdings: dict[str, set[str]] = field(default_factory=dict) # thread -> resources held resource_owner: dict[str, str] = field(default_factory=dict) # resource -> owner thread wait_for_resource: dict[str, str] = field(default_factory=dict) # thread -> resource def process_event(self, event: LockEvent) -> Optional[DeadlockInfo]: """Process an event and return deadlock info if one is detected.""" if event.event_type == EventType.ACQUIRE: return self._handle_acquire(event.thread_id, event.resource_id) elif event.event_type == EventType.RELEASE: self._handle_release(event.thread_id, event.resource_id) elif event.event_type == EventType.WAIT: return self._handle_wait(event.thread_id, event.resource_id) return None def _handle_acquire(self, thread: str, resource: str) -> Optional[DeadlockInfo]: if resource in self.resource_owner: if self.resource_owner[resource] != thread: return None # Blocked - should use WAIT event self.resource_owner[resource] = thread self.holdings.setdefault(thread, set()).add(resource) self.wait_for_resource.pop(thread, None) return None def _handle_release(self, thread: str, resource: str) -> None: if self.resource_owner.get(resource) == thread: del self.resource_owner[resource] self.holdings.get(thread, set()).discard(resource) def _handle_wait(self, thread: str, resource: str) -> Optional[DeadlockInfo]: owner = self.resource_owner.get(resource) if not owner or owner == thread: return None self.wait_for_resource[thread] = resource return self._check_deadlock(thread) def _check_deadlock(self, start_thread: str) -> Optional[DeadlockInfo]: """Check if adding this wait creates a cycle.""" visited = set() path_threads = [] path_resources = [] current = start_thread while current and current not in visited: visited.add(current) path_threads.append(current) resource = self.wait_for_resource.get(current) if not resource: return None path_resources.append(resource) owner = self.resource_owner.get(resource) if owner == start_thread: # Found cycle wait_chain = [ (path_threads[i], path_threads[(i + 1) % len(path_threads)], path_resources[i]) for i in range(len(path_threads)) ] return DeadlockInfo( threads=path_threads, resources=path_resources, wait_chain=wait_chain, ) current = owner return None
Question 36 - Metrics Collector
Difficulty: 5 / 10
Approximate lines of code: 90 LoC
Tags: storage
Description
A metrics collector is the foundation of application monitoring systems like Prometheus, StatsD, or Datadog. It tracks three fundamental metric types: counters (monotonically increasing values like request counts), gauges (point-in-time values like CPU usage), and histograms (distributions of observations like latency). Each observation is timestamped to enable time-range queries.
The core data structures are: (1) timestamped value lists per metric name, and (2) aggregation logic that differs by metric type. Counters accumulate, gauges return the latest value, and histograms compute statistics over observations.
Note: The sortedcontainers library (SortedList) is available. For histogram percentile queries, a SortedList keeps observations ordered with O(log n) insertions, versus the O(n) cost of bisect.insort shifting elements in a plain list.
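For example, a rough sketch of the SortedList idea the note refers to (standalone, not tied to any solution below):
from sortedcontainers import SortedList

latencies = SortedList()
for v in [120, 95, 240, 180, 60]:
    latencies.add(v)                # stays sorted; ~O(log n) per insert

def percentile(p: float) -> float:
    # index into the sorted observations; clamp to the last element
    idx = min(int(len(latencies) * p / 100), len(latencies) - 1)
    return latencies[idx]

print(percentile(50), percentile(99))   # 120 240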
Part A: Counters and Gauges
Problem: Part A
Implement counters that only increment and gauges that can be set to any value.
Support querying metrics within a specific time window. Filter observations by start/end timestamps.
collector = MetricsCollector()
t1 = time.time()
collector.observe("latency", 100)
time.sleep(1)
t2 = time.time()
collector.observe("latency", 200)
time.sleep(1)
t3 = time.time()
collector.observe("latency", 300)

# Query specific time range
stats = collector.get_histogram("latency", start=t2, end=t3)
# Returns stats for observations between t2 and t3 only
# {"count": 2, "sum": 500, "min": 200, "max": 300, "avg": 250}

# Get gauge value at specific time
collector.get_gauge("cpu", start=t1, end=t2)  # Latest value in range
Interview comments
Edge cases to probe:
What does get_histogram return for a metric with zero observations?
What happens if you query a counter for a time range before any increments?
How do you handle avg when count is zero (division by zero)?
Should time range be inclusive or exclusive on boundaries?
Common mistakes:
Division by zero when computing average of empty histogram
Allowing negative increments on counters
Not storing timestamps, making time-range queries impossible
Returning 0 instead of None for missing gauges (0 is a valid gauge value)
O(n) linear scan for every query instead of using efficient data structures
Code solutions
Solution 1 uses a simple dictionary-based approach with timestamped entries for each metric type. Solution 2 takes an object-oriented approach with separate Counter, Gauge, and Histogram classes. Solution 3 uses time-bucketed storage with sorted insertion for efficient range queries and percentile calculations. The key difference is storage granularity: flat lists vs encapsulated classes vs time-bucketed structures. Core techniques: timestamped entries, list comprehension filtering, bisect for sorted insertion, defaultdict for auto-initialization.
Solution 1: Simple dictionary-based approach
Simple dictionary-based approach with timestamped entries. Each metric type stored in separate defaultdict. Time filtering done via list comprehension.
"""Metrics Collector - Solution 1: Simple In-Memory ImplementationUses dictionaries with timestamped entries for all metric types."""from dataclasses import dataclass, fieldfrom typing import Optionalfrom collections import defaultdictimport time@dataclassclass MetricEntry: value: float timestamp: float@dataclassclass MetricsCollector: counters: dict[str, list[MetricEntry]] = field(default_factory=lambda: defaultdict(list)) gauges: dict[str, list[MetricEntry]] = field(default_factory=lambda: defaultdict(list)) histograms: dict[str, list[MetricEntry]] = field(default_factory=lambda: defaultdict(list)) def _now(self) -> float: return time.time() def increment(self, name: str, value: float = 1.0) -> None: """Increment a counter by value (must be positive).""" if value < 0: raise ValueError("Counter increments must be non-negative") entries = self.counters[name] current = entries[-1].value if entries else 0.0 entries.append(MetricEntry(current + value, self._now())) def set_gauge(self, name: str, value: float) -> None: """Set a gauge to an arbitrary value.""" self.gauges[name].append(MetricEntry(value, self._now())) def observe(self, name: str, value: float) -> None: """Record an observation in a histogram.""" self.histograms[name].append(MetricEntry(value, self._now())) def get_counter(self, name: str, start: Optional[float] = None, end: Optional[float] = None) -> float: """Get current counter value, optionally filtered by time range.""" entries = self._filter_by_time(self.counters[name], start, end) return entries[-1].value if entries else 0.0 def get_gauge(self, name: str, start: Optional[float] = None, end: Optional[float] = None) -> Optional[float]: """Get most recent gauge value in time range.""" entries = self._filter_by_time(self.gauges[name], start, end) return entries[-1].value if entries else None def get_histogram(self, name: str, start: Optional[float] = None, end: Optional[float] = None) -> dict: """Get histogram statistics: count, sum, min, max, avg.""" entries = self._filter_by_time(self.histograms[name], start, end) if not entries: return {"count": 0, "sum": 0.0, "min": None, "max": None, "avg": None} values = [e.value for e in entries] return { "count": len(values), "sum": sum(values), "min": min(values), "max": max(values), "avg": sum(values) / len(values), } def _filter_by_time( self, entries: list[MetricEntry], start: Optional[float], end: Optional[float] ) -> list[MetricEntry]: result = entries if start is not None: result = [e for e in result if e.timestamp >= start] if end is not None: result = [e for e in result if e.timestamp <= end] return result
Solution 2: Object-oriented with metric classes
Object-oriented approach with separate Counter, Gauge, and Histogram classes. Each metric type encapsulates its own behavior and storage. A MetricsRegistry provides named access to metrics.
"""Metrics Collector - Solution 2: Separation of Concerns with Metric ClassesEach metric type is its own class with dedicated storage and behavior."""from dataclasses import dataclass, fieldfrom typing import Optional, Protocolfrom abc import ABC, abstractmethodimport time@dataclassclass TimestampedValue: value: float timestamp: float = field(default_factory=time.time)class Metric(ABC): @abstractmethod def query(self, start: Optional[float] = None, end: Optional[float] = None) -> any: pass@dataclassclass Counter(Metric): _values: list[TimestampedValue] = field(default_factory=list) _total: float = 0.0 def increment(self, amount: float = 1.0) -> None: if amount < 0: raise ValueError("Cannot decrement a counter") self._total += amount self._values.append(TimestampedValue(self._total)) def query(self, start: Optional[float] = None, end: Optional[float] = None) -> float: filtered = self._filter(start, end) return filtered[-1].value if filtered else 0.0 def _filter(self, start: Optional[float], end: Optional[float]) -> list[TimestampedValue]: return [v for v in self._values if (start is None or v.timestamp >= start) and (end is None or v.timestamp <= end)]@dataclassclass Gauge(Metric): _values: list[TimestampedValue] = field(default_factory=list) def set(self, value: float) -> None: self._values.append(TimestampedValue(value)) def query(self, start: Optional[float] = None, end: Optional[float] = None) -> Optional[float]: filtered = [v for v in self._values if (start is None or v.timestamp >= start) and (end is None or v.timestamp <= end)] return filtered[-1].value if filtered else None@dataclassclass Histogram(Metric): _observations: list[TimestampedValue] = field(default_factory=list) def observe(self, value: float) -> None: self._observations.append(TimestampedValue(value)) def query(self, start: Optional[float] = None, end: Optional[float] = None) -> dict: filtered = [v.value for v in self._observations if (start is None or v.timestamp >= start) and (end is None or v.timestamp <= end)] if not filtered: return {"count": 0, "sum": 0.0, "min": None, "max": None, "avg": None} return { "count": len(filtered), "sum": sum(filtered), "min": min(filtered), "max": max(filtered), "avg": sum(filtered) / len(filtered), }@dataclassclass MetricsRegistry: _counters: dict[str, Counter] = field(default_factory=dict) _gauges: dict[str, Gauge] = field(default_factory=dict) _histograms: dict[str, Histogram] = field(default_factory=dict) def counter(self, name: str) -> Counter: if name not in self._counters: self._counters[name] = Counter() return self._counters[name] def gauge(self, name: str) -> Gauge: if name not in self._gauges: self._gauges[name] = Gauge() return self._gauges[name] def histogram(self, name: str) -> Histogram: if name not in self._histograms: self._histograms[name] = Histogram() return self._histograms[name]
Solution 3: Time-bucketed storage with percentiles
Time-bucketed storage for efficient range queries. Uses sorted insertion (bisect) within buckets for percentile calculations. Supports percentile queries (p50, p90, p99) in addition to basic statistics.
"""Metrics Collector - Solution 3: Time-Bucketed Storage with PercentilesUses time buckets for efficient range queries and supports percentile calculations."""from dataclasses import dataclass, fieldfrom typing import Optionalfrom bisect import bisect_left, bisect_right, insortimport timeBUCKET_SIZE = 60.0 # 1 minute buckets@dataclassclass TimeBucket: start_time: float values: list[float] = field(default_factory=list) def add(self, value: float) -> None: insort(self.values, value) # Keep sorted for percentile queries@dataclassclass TimeSeriesStore: buckets: list[TimeBucket] = field(default_factory=list) _bucket_starts: list[float] = field(default_factory=list) def _get_bucket(self, timestamp: float) -> TimeBucket: bucket_start = (timestamp // BUCKET_SIZE) * BUCKET_SIZE idx = bisect_left(self._bucket_starts, bucket_start) if idx < len(self._bucket_starts) and self._bucket_starts[idx] == bucket_start: return self.buckets[idx] bucket = TimeBucket(bucket_start) self.buckets.insert(idx, bucket) self._bucket_starts.insert(idx, bucket_start) return bucket def add(self, value: float, timestamp: Optional[float] = None) -> None: ts = timestamp or time.time() self._get_bucket(ts).add(value) def get_values(self, start: Optional[float], end: Optional[float]) -> list[float]: result = [] for bucket in self.buckets: if start and bucket.start_time + BUCKET_SIZE < start: continue if end and bucket.start_time > end: break result.extend(bucket.values) return result@dataclassclass MetricsCollector: _counters: dict[str, float] = field(default_factory=dict) _counter_history: dict[str, TimeSeriesStore] = field(default_factory=dict) _gauges: dict[str, TimeSeriesStore] = field(default_factory=dict) _histograms: dict[str, TimeSeriesStore] = field(default_factory=dict) def increment(self, name: str, value: float = 1.0) -> None: if value < 0: raise ValueError("Counter increment must be non-negative") self._counters[name] = self._counters.get(name, 0.0) + value if name not in self._counter_history: self._counter_history[name] = TimeSeriesStore() self._counter_history[name].add(self._counters[name]) def set_gauge(self, name: str, value: float) -> None: if name not in self._gauges: self._gauges[name] = TimeSeriesStore() self._gauges[name].add(value) def observe(self, name: str, value: float) -> None: if name not in self._histograms: self._histograms[name] = TimeSeriesStore() self._histograms[name].add(value) def get_counter(self, name: str) -> float: return self._counters.get(name, 0.0) def get_gauge(self, name: str, start: Optional[float] = None, end: Optional[float] = None) -> Optional[float]: if name not in self._gauges: return None values = self._gauges[name].get_values(start, end) return values[-1] if values else None def get_histogram( self, name: str, start: Optional[float] = None, end: Optional[float] = None, percentiles: list[float] = None ) -> dict: if name not in self._histograms: return {"count": 0, "sum": 0.0, "min": None, "max": None, "avg": None, "percentiles": {}} values = self._histograms[name].get_values(start, end) if not values: return {"count": 0, "sum": 0.0, "min": None, "max": None, "avg": None, "percentiles": {}} sorted_vals = sorted(values) pcts = {} for p in (percentiles or []): idx = int(len(sorted_vals) * p / 100) pcts[p] = sorted_vals[min(idx, len(sorted_vals) - 1)] return { "count": len(values), "sum": sum(values), "min": min(values), "max": max(values), "avg": sum(values) / len(values), "percentiles": pcts, }
Question 37 - Packet Reassembler
Difficulty: 4 / 10
Approximate lines of code: 80 LoC
Tags: data-structures
Description
Network packets often arrive out of order due to routing differences, retransmissions, and varying network paths. A packet reassembler buffers incoming packets and reconstructs the original data stream by ordering packets by their sequence numbers. The core data structure is a dictionary or buffer mapping sequence numbers to packet data, combined with a pointer tracking the next expected sequence number. When contiguous packets are available starting from this pointer, they can be “flushed” to produce the reassembled output.
The key insight is that you need O(1) lookup by sequence number (dictionary) while also tracking what’s missing. Gap detection becomes important when you need to request retransmissions - you need to know which sequence numbers haven’t arrived yet between your expected position and the highest seen.
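As a tiny illustration of that gap bookkeeping (illustrative names, not taken from the solutions below):
received = {0, 1, 3, 6}                    # sequence numbers buffered so far
expected = 0                               # next sequence we can flush
highest = max(received)
missing = sorted(set(range(expected, highest)) - received)
print(missing)                             # [2, 4, 5]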
Part A: Basic Buffering and Reassembly
Problem: Part A
Implement a packet reassembler with receive(packet) and reassemble() methods. Packets have a sequence number and data payload. Buffer out-of-order packets until you can produce a contiguous stream.
Add get_missing_sequences() to return gaps between expected and highest received. Also handle duplicates - receive() should return False for duplicate packets.
Add timeout tracking so you can request retransmissions for missing packets that have been missing “too long”. Track when gaps are first detected and return timed-out sequences.
reassembler = PacketReassembler(timeout_seconds=5.0)
reassembler.receive(Packet(seq_num=0, data=b"A"))
reassembler.receive(Packet(seq_num=3, data=b"D"))  # Gap detected at t=0

# At t=3 seconds
retransmit = reassembler.get_retransmit_requests()
# Returns [] (not timed out yet)

# At t=6 seconds
retransmit = reassembler.get_retransmit_requests()
# Returns [1, 2] (missing sequences that have timed out)
Interview comments
Edge cases to probe:
What happens if you receive the same packet twice?
What if packets arrive before sequence 0?
How do you handle window overflow (very large gaps)?
What if packet data is empty?
Common mistakes:
Not advancing expected_seq after reassembly
O(n) linear scan for missing sequences instead of using set difference
Storing already-processed packets (memory leak)
Not removing packets from buffer after reassembly
Code solutions
Solution 1 uses a simple dictionary mapping sequence numbers to packets, with linear iteration to find contiguous runs. Solution 2 uses a min-heap ordered by sequence number for efficient retrieval of the next expected packet, with streaming callback support. Solution 3 maintains explicit Gap objects to track missing ranges, which is more memory-efficient for sparse sequences. The key difference is how they track and detect gaps: dictionary lookup vs heap ordering vs explicit gap ranges. Core techniques: hash maps, min-heaps, interval tracking.
Solution 1: Simple Dictionary-Based Approach
A straightforward dictionary-based approach. Packets are stored by sequence number, and reassembly iterates from expected_seq until a gap is found. Gap detection scans from expected to max seen.
"""Packet Reassembler - Solution 1: Simple Dictionary-Based ApproachUses a dictionary to store packets by sequence number, with periodicreassembly checks."""from dataclasses import dataclassfrom typing import Optionalimport time@dataclassclass Packet: seq_num: int data: bytes timestamp: float = 0.0 def __post_init__(self) -> None: if self.timestamp == 0.0: self.timestamp = time.time()class PacketReassembler: def __init__(self, timeout_seconds: float = 5.0) -> None: self.packets: dict[int, Packet] = {} self.expected_seq: int = 0 self.timeout_seconds = timeout_seconds self.duplicates_received: int = 0 def receive(self, packet: Packet) -> bool: """Receive a packet. Returns True if new, False if duplicate.""" if packet.seq_num in self.packets: self.duplicates_received += 1 return False self.packets[packet.seq_num] = packet return True def get_missing_sequences(self) -> list[int]: """Return list of missing sequence numbers up to highest received.""" if not self.packets: return [] max_seq = max(self.packets.keys()) missing = [] for seq in range(self.expected_seq, max_seq): if seq not in self.packets: missing.append(seq) return missing def get_timed_out_sequences(self) -> list[int]: """Return missing sequences that have timed out.""" if not self.packets: return [] current_time = time.time() oldest_packet_time = min(p.timestamp for p in self.packets.values()) if current_time - oldest_packet_time < self.timeout_seconds: return [] return self.get_missing_sequences() def reassemble(self) -> Optional[bytes]: """Try to reassemble contiguous packets from expected_seq.""" if not self.packets: return None result = bytearray() while self.expected_seq in self.packets: result.extend(self.packets[self.expected_seq].data) del self.packets[self.expected_seq] self.expected_seq += 1 return bytes(result) if result else None
Solution 2: Heap-Based with Streaming Output
A heap-based streaming approach. Uses a min-heap ordered by sequence number for efficient retrieval of the next expected packet. Supports callbacks for real-time streaming output.
"""Packet Reassembler - Solution 2: Heap-Based with Streaming OutputUses a min-heap for efficient retrieval of the next expected packet.Supports streaming reassembly with callbacks."""from dataclasses import dataclass, fieldfrom typing import Callable, Optionalimport heapqimport time@dataclass(order=True)class Packet: seq_num: int data: bytes = field(compare=False) timestamp: float = field(compare=False, default_factory=time.time)class StreamingPacketReassembler: def __init__( self, on_data: Optional[Callable[[bytes], None]] = None, timeout_seconds: float = 5.0, ) -> None: self.heap: list[Packet] = [] self.seen_sequences: set[int] = set() self.expected_seq: int = 0 self.on_data = on_data self.timeout_seconds = timeout_seconds self.duplicate_count: int = 0 def receive(self, packet: Packet) -> bool: """Receive packet, auto-flush if contiguous. Returns False if dup.""" if packet.seq_num in self.seen_sequences: self.duplicate_count += 1 return False self.seen_sequences.add(packet.seq_num) heapq.heappush(self.heap, packet) self._try_flush() return True def _try_flush(self) -> None: """Flush contiguous packets to callback.""" while self.heap and self.heap[0].seq_num == self.expected_seq: packet = heapq.heappop(self.heap) if self.on_data: self.on_data(packet.data) self.expected_seq += 1 def get_missing_sequences(self) -> list[int]: """Return missing sequence numbers.""" if not self.heap: return [] max_seq = max(p.seq_num for p in self.heap) return [s for s in range(self.expected_seq, max_seq) if s not in self.seen_sequences] def request_retransmit(self) -> list[int]: """Return sequences needing retransmit based on timeout.""" if not self.heap: return [] now = time.time() oldest = min(p.timestamp for p in self.heap) if now - oldest >= self.timeout_seconds: return self.get_missing_sequences() return [] def get_buffered_data(self) -> bytes: """Get all buffered data in order (doesn't advance expected_seq).""" sorted_packets = sorted(self.heap, key=lambda p: p.seq_num) return b"".join(p.data for p in sorted_packets)
Solution 3: Sliding Window with Gap Tracking
A sliding window approach with explicit gap tracking. Maintains a list of Gap objects representing missing ranges, which is more memory-efficient for sparse sequences than tracking individual missing numbers.
"""Packet Reassembler - Solution 3: Sliding Window with Gap TrackingUses a sliding window buffer with explicit gap tracking for efficientmissing packet detection. More memory-efficient for sparse sequences."""from dataclasses import dataclassfrom typing import Optionalimport time@dataclassclass Packet: seq_num: int data: bytes timestamp: float = 0.0 def __post_init__(self) -> None: if self.timestamp == 0.0: self.timestamp = time.time()@dataclassclass Gap: start: int # inclusive end: int # exclusive first_detected: float = 0.0 def __post_init__(self) -> None: if self.first_detected == 0.0: self.first_detected = time.time()class SlidingWindowReassembler: def __init__(self, window_size: int = 1024, timeout_s: float = 5.0) -> None: self.window_size = window_size self.timeout_seconds = timeout_s self.buffer: dict[int, Packet] = {} self.gaps: list[Gap] = [] self.next_expected: int = 0 self.highest_seen: int = -1 self.duplicates: int = 0 def receive(self, packet: Packet) -> bool: """Receive a packet. Returns False if duplicate or out of window.""" seq = packet.seq_num # Reject if already processed or duplicate if seq < self.next_expected or seq in self.buffer: self.duplicates += 1 return False # Reject if too far ahead if seq >= self.next_expected + self.window_size: return False self.buffer[seq] = packet # Update gap tracking if seq > self.highest_seen + 1: self.gaps.append(Gap(start=self.highest_seen + 1, end=seq)) if seq > self.highest_seen: self.highest_seen = seq self._update_gaps(seq) return True def _update_gaps(self, filled_seq: int) -> None: """Remove or shrink gaps when a sequence is filled.""" new_gaps = [] for gap in self.gaps: if filled_seq < gap.start or filled_seq >= gap.end: new_gaps.append(gap) elif filled_seq == gap.start: if gap.start + 1 < gap.end: new_gaps.append(Gap(gap.start + 1, gap.end, gap.first_detected)) elif filled_seq == gap.end - 1: if gap.start < gap.end - 1: new_gaps.append(Gap(gap.start, gap.end - 1, gap.first_detected)) else: new_gaps.append(Gap(gap.start, filled_seq, gap.first_detected)) new_gaps.append(Gap(filled_seq + 1, gap.end, gap.first_detected)) self.gaps = new_gaps def reassemble(self) -> Optional[bytes]: """Reassemble contiguous packets from next_expected.""" result = bytearray() while self.next_expected in self.buffer: result.extend(self.buffer.pop(self.next_expected).data) self.next_expected += 1 # Clean up gaps below next_expected self.gaps = [g for g in self.gaps if g.end > self.next_expected] return bytes(result) if result else None def get_retransmit_requests(self) -> list[int]: """Return timed-out missing sequences needing retransmit.""" now = time.time() result = [] for gap in self.gaps: if now - gap.first_detected >= self.timeout_seconds: result.extend(range(gap.start, gap.end)) return [s for s in result if s >= self.next_expected]
Question 38 - Bloom Filter
Difficulty: 8 / 10
Approximate lines of code: 70 LoC
Tags: probabilistic, data-structures
Description
A Bloom filter is a space-efficient probabilistic data structure for set membership testing. Instead of storing actual elements, it uses a bit array and multiple hash functions. When you add an element, you hash it with k different functions and set the corresponding k bits to 1. To check membership, you verify all k bits are set. The key insight: false negatives are impossible (if an element was added, its bits are definitely set), but false positives can occur (bits might be set by other elements).
The tradeoff is simple: more bits and more hash functions reduce false positives, but use more memory and CPU. For n expected items and desired false positive rate p, optimal bit array size is m = -n*ln(p)/(ln(2)^2) and optimal hash count is k = (m/n)*ln(2).
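As a concrete check of those formulas (a throwaway calculation, not part of the solutions):
import math

n, p = 1_000_000, 0.01                          # expected items, target FP rate
m = int(-n * math.log(p) / (math.log(2) ** 2))  # ~9,585,058 bits (~1.2 MB)
k = round((m / n) * math.log(2))                # ~7 hash functions
print(m, k)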
Part A: Basic Structure
Problem: Part A
Implement a Bloom filter with add(item) and might_contain(item) methods. The filter needs a bit array of fixed size and k hash functions. Use salted hashing (prepend index to item before hashing) to simulate multiple hash functions from one.
bf = BloomFilter(size=1000, num_hashes=3)
bf.add("apple")
# Internal: hash_0("apple") = 142, hash_1("apple") = 567, hash_2("apple") = 891
# bit_array: [..., 1@142, ..., 1@567, ..., 1@891, ...]
bf.add("banana")
# Sets bits at positions 234, 567, 789 (567 overlaps with "apple")

bf.might_contain("apple")   # True - bits 142, 567, 891 all set
bf.might_contain("cherry")  # Likely False - unless all its bit positions happen to be set
bf.might_contain("grape")   # Could be True (false positive) if its bits were set by other items
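One hedged way to do the salted hashing with a hash that is stable across runs (hashlib.sha256 is my choice here, not something the problem mandates):
import hashlib

def bit_index(item: str, salt: int, size: int) -> int:
    # Prepend the salt so each "hash function" sees a different input
    digest = hashlib.sha256(f"{salt}:{item}".encode()).hexdigest()
    return int(digest, 16) % size

print([bit_index("apple", i, 1000) for i in range(3)])   # three pseudo-independent positions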
Part B: Optimal Sizing
Problem: Part B
Add static methods to calculate optimal parameters given expected items and desired false positive rate:
optimal_size(expected_items, fp_rate) returns bit array size
optimal_num_hashes(size, expected_items) returns number of hash functions
Add expected_false_positive_rate() that estimates the current FP probability based on items added. The formula is (1 - e^(-kn/m))^k where k=hash count, n=items added, m=size.
bf = BloomFilter(size=1000, num_hashes=3)
for word in ["apple", "banana", "cherry"]:
    bf.add(word)
fp_rate = bf.expected_false_positive_rate()  # ~7.2e-7 for these parameters
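Plugging the stated parameters (k=3, n=3, m=1000) straight into the Part B formula confirms how tiny the rate is at this fill level (a standalone check):
import math

k, n, m = 3, 3, 1000
print((1 - math.exp(-k * n / m)) ** k)   # ~7.2e-07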
Interview comments
Edge cases to probe:
What happens if you add the same element twice? (No effect - bits already set)
Can you remove elements? (Not from basic Bloom filter - need counting variant)
What if num_hashes > size? (Bits overlap more, higher FP rate)
Why not use Python’s built-in hash()? (Non-deterministic across runs)
Common mistakes:
Using a single hash function (defeats purpose - need k independent hashes)
Using Python’s hash() which isn’t stable across interpreter sessions
Forgetting that false positive rate increases as more items are added
Using a list of bools instead of bit manipulation (at least 8x memory overhead)
Code solutions
Solutions Overview
Solution 1 uses salted hashing with a boolean list for simplicity. Solution 2 upgrades to double hashing (MD5+SHA1) with bytearray bit manipulation for memory efficiency. Solution 3 implements a counting Bloom filter with 4-bit counters and FNV-1a hashing to support deletions. The key difference is the tradeoff between simplicity, memory efficiency, and functionality (deletion support). Core techniques: hash functions, bit manipulation, probabilistic data structures.
Solution 1: Basic Implementation
Basic implementation using salted hashing (prepending index to create k hash functions) with a simple boolean list for the bit array. Straightforward but not memory-optimized.
"""Bloom Filter - Solution 1: Basic ImplementationUses built-in hash with salt values for multiple hash functions."""from dataclasses import dataclass, fieldimport math@dataclassclass BloomFilter: """Probabilistic set membership data structure.""" size: int # Number of bits in the filter num_hashes: int # Number of hash functions to use bit_array: list[bool] = field(default_factory=list, repr=False) _count: int = field(default=0, repr=False) def __post_init__(self) -> None: if self.size <= 0: raise ValueError("Size must be positive") if self.num_hashes <= 0: raise ValueError("Number of hash functions must be positive") self.bit_array = [False] * self.size def _get_hash_indices(self, item: str) -> list[int]: """Generate k hash indices for an item using salted hashing.""" indices = [] for i in range(self.num_hashes): # Use salt to create different hash functions salted = f"{i}:{item}" hash_value = hash(salted) % self.size indices.append(hash_value) return indices def add(self, item: str) -> None: """Add an element to the Bloom filter.""" for index in self._get_hash_indices(item): self.bit_array[index] = True self._count += 1 def might_contain(self, item: str) -> bool: """Check if element might be in the set. May have false positives.""" return all(self.bit_array[i] for i in self._get_hash_indices(item)) def expected_false_positive_rate(self) -> float: """Calculate expected false positive probability.""" # Formula: (1 - e^(-k*n/m))^k # k = num_hashes, n = items added, m = size if self._count == 0: return 0.0 exponent = -self.num_hashes * self._count / self.size return (1 - math.exp(exponent)) ** self.num_hashes @staticmethod def optimal_size(expected_items: int, fp_rate: float) -> int: """Calculate optimal bit array size for given items and FP rate.""" # m = -n * ln(p) / (ln(2)^2) return int(-expected_items * math.log(fp_rate) / (math.log(2) ** 2)) @staticmethod def optimal_num_hashes(size: int, expected_items: int) -> int: """Calculate optimal number of hash functions.""" # k = (m/n) * ln(2) return max(1, int((size / expected_items) * math.log(2)))
Solution 2: Double Hashing with Bytearray
Production-quality implementation using double hashing technique (h(i) = h1 + i*h2) with MD5 and SHA1. Uses bytearray for memory efficiency with bit manipulation for get/set operations.
"""Bloom Filter - Solution 2: Using mmh3 (MurmurHash3)Production-quality implementation with double hashing technique."""from dataclasses import dataclass, fieldimport mathimport hashlib@dataclassclass BloomFilter: """Bloom filter using double hashing for hash function generation.""" size: int num_hashes: int _bits: bytearray = field(default_factory=bytearray, repr=False) _count: int = field(default=0, repr=False) def __post_init__(self) -> None: if self.size <= 0 or self.num_hashes <= 0: raise ValueError("Size and num_hashes must be positive") # Use bytearray for memory efficiency self._bits = bytearray((self.size + 7) // 8) def _hash_pair(self, item: str) -> tuple[int, int]: """Generate two independent hash values using MD5 and SHA1.""" encoded = item.encode("utf-8") h1 = int(hashlib.md5(encoded).hexdigest(), 16) h2 = int(hashlib.sha1(encoded).hexdigest(), 16) return h1, h2 def _get_indices(self, item: str) -> list[int]: """Double hashing: h(i) = h1 + i*h2 mod m.""" h1, h2 = self._hash_pair(item) return [(h1 + i * h2) % self.size for i in range(self.num_hashes)] def _set_bit(self, index: int) -> None: """Set a bit in the bytearray.""" self._bits[index // 8] |= 1 << (index % 8) def _get_bit(self, index: int) -> bool: """Get a bit from the bytearray.""" return bool(self._bits[index // 8] & (1 << (index % 8))) def add(self, item: str) -> None: """Add an element to the filter.""" for idx in self._get_indices(item): self._set_bit(idx) self._count += 1 def might_contain(self, item: str) -> bool: """Check if element might exist (may have false positives).""" return all(self._get_bit(idx) for idx in self._get_indices(item)) def expected_fp_rate(self) -> float: """Expected false positive rate: (1 - e^(-kn/m))^k.""" if self._count == 0: return 0.0 p = 1 - math.exp(-self.num_hashes * self._count / self.size) return p ** self.num_hashes def fill_ratio(self) -> float: """Ratio of bits set to total bits.""" set_bits = sum(bin(byte).count("1") for byte in self._bits) return set_bits / self.size @classmethod def create_optimal(cls, expected_items: int, fp_rate: float) -> "BloomFilter": """Factory method to create optimally-sized Bloom filter.""" size = int(-expected_items * math.log(fp_rate) / (math.log(2) ** 2)) num_hashes = max(1, int((size / expected_items) * math.log(2))) return cls(size=size, num_hashes=num_hashes)
Solution 3: Counting Bloom Filter
Counting Bloom filter variant that supports deletions by using 4-bit counters instead of single bits. Uses FNV-1a hash algorithm. Increment counters on add, decrement on remove.
"""Bloom Filter - Solution 3: Counting Bloom FilterSupports deletions by using counters instead of bits."""from dataclasses import dataclass, fieldfrom typing import Anyimport mathimport struct@dataclassclass CountingBloomFilter: """Bloom filter variant that supports deletions using 4-bit counters.""" size: int num_hashes: int _counters: list[int] = field(default_factory=list, repr=False) _count: int = field(default=0, repr=False) MAX_COUNT: int = field(default=15, repr=False) # 4-bit counter max def __post_init__(self) -> None: if self.size <= 0 or self.num_hashes <= 0: raise ValueError("Size and num_hashes must be positive") self._counters = [0] * self.size def _hash(self, item: Any, seed: int) -> int: """Generate hash using FNV-1a algorithm with seed.""" data = f"{seed}:{item}".encode("utf-8") h = 2166136261 # FNV offset basis for byte in data: h ^= byte h = (h * 16777619) & 0xFFFFFFFF # FNV prime, 32-bit return h % self.size def _get_indices(self, item: Any) -> list[int]: """Get all hash indices for an item.""" return [self._hash(item, i) for i in range(self.num_hashes)] def add(self, item: Any) -> None: """Add element to filter. Increments counters.""" for idx in self._get_indices(item): if self._counters[idx] < self.MAX_COUNT: self._counters[idx] += 1 self._count += 1 def remove(self, item: Any) -> bool: """Remove element. Returns True if element was possibly present.""" if not self.might_contain(item): return False for idx in self._get_indices(item): if self._counters[idx] > 0: self._counters[idx] -= 1 self._count = max(0, self._count - 1) return True def might_contain(self, item: Any) -> bool: """Check membership. False positives possible, no false negatives.""" return all(self._counters[idx] > 0 for idx in self._get_indices(item)) def expected_fp_rate(self) -> float: """Calculate expected false positive probability.""" if self._count == 0: return 0.0 exp = -self.num_hashes * self._count / self.size return (1 - math.exp(exp)) ** self.num_hashes def memory_usage_bytes(self) -> int: """Estimate memory usage in bytes.""" # Python int overhead is significant; real impl would pack counters return self.size * 4 # Assuming packed 4-bit counters = 0.5 bytes each @classmethod def optimal(cls, n: int, fp: float) -> "CountingBloomFilter": """Create optimally-sized counting Bloom filter.""" m = int(-n * math.log(fp) / (math.log(2) ** 2)) k = max(1, int((m / n) * math.log(2))) return cls(size=m, num_hashes=k)
Question 39 - Cron Parser
Difficulty: 6 / 10
Approximate lines of code: 80 LoC
Tags: scheduling
Description
Cron is a time-based job scheduler in Unix-like systems. A cron expression consists of 5 fields: minute (0-59), hour (0-23), day of month (1-31), month (1-12), and day of week (0-6, where 0=Sunday). Each field supports wildcards (*), ranges (1-5), lists (1,3,5), and steps (*/15). The parser must convert these expressions into a set of valid values for each field, then compute when a job should next run.
Key implementation detail: Python’s datetime.weekday() returns Monday=0, but cron uses Sunday=0. You must convert: cron_weekday = (python_weekday + 1) % 7.
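A one-off sanity check of that conversion (2024-01-07 was a Sunday):
from datetime import datetime

dt = datetime(2024, 1, 7)
print(dt.weekday())             # 6  (Python: Monday=0 ... Sunday=6)
print((dt.weekday() + 1) % 7)   # 0  (cron: Sunday=0)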
Note: The sortedcontainers library (SortedList) is available. Using SortedList for field values enables O(log n) lookup of the next valid value via bisect_left.
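A minimal sketch of that "next valid value" lookup, assuming a parsed minute field (illustrative names, not from the solutions below):
from sortedcontainers import SortedList

minutes = SortedList([0, 15, 30, 45])        # parsed from "*/15"

def next_minute(current: int) -> tuple[int, bool]:
    idx = minutes.bisect_left(current)
    if idx == len(minutes):
        return minutes[0], True              # wrapped: roll over to the next hour
    return minutes[idx], False

print(next_minute(20))   # (30, False)
print(next_minute(50))   # (0, True)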
Part A: Parse Cron Expression
Problem: Part A
Parse a 5-field cron expression into sets of valid values for each field.
cron = CronExpression.parse("*/15 9-17 * * 1-5")
# Internal state after parsing:
# minute: {0, 15, 30, 45} - every 15 minutes
# hour: {9, 10, 11, ..., 17} - 9am to 5pm
# day: {1, 2, ..., 31} - every day
# month: {1, 2, ..., 12} - every month
# weekday: {1, 2, 3, 4, 5} - Monday through Friday

cron2 = CronExpression.parse("0 0 1 * *")
# minute: {0}, hour: {0}, day: {1}, month: {1-12}, weekday: {0-6}
# Runs at midnight on the 1st of each month
Syntax rules:
* = all valid values
5 = specific value
1-5 = range (inclusive)
1,3,5 = list
*/15 = every 15 starting from the field minimum
1-10/2 = every 2 in range 1-10 (i.e., 1,3,5,7,9)
Part B: Calculate Next Run Time
Problem: Part B
Given a starting datetime, find the next time the cron expression matches.
The naive approach increments by 1 minute until a match is found. A smarter approach jumps forward field by field (month, day, hour, minute) to avoid iterating through 500k+ minutes per year.
Part C: Validation and Combination Syntax
Problem: Part C
Add robust validation and support combined syntax like 1-10/2 (every 2nd value in range 1-10).
# Valid expressions
cron = CronExpression.parse("0,30 9-17/2 * * *")
# minute: {0, 30}
# hour: {9, 11, 13, 15, 17} - every 2 hours from 9-17

# Invalid expressions should raise clear errors
try:
    CronExpression.parse("60 * * * *")  # minute out of range
except ValueError as e:
    print(e)  # "Out of range in minute: [60] (valid: 0-59)"

try:
    CronExpression.parse("* * * *")  # wrong number of fields
except ValueError as e:
    print(e)  # "Need 5 fields, got 4"
Interview comments
Interview comments
Edge cases to probe:
What happens with 0 0 31 2 * (Feb 31st doesn’t exist)?
How do you handle weekday mapping (Python Mon=0, cron Sun=0)?
What if no valid time exists within a year?
Does 5/10 mean “every 10 starting at 5” or something else?
Common mistakes:
Infinite loop in next_run if expression can never match (e.g., Feb 30th)
Off-by-one in range parsing: 1-5 should include 5
Not handling step on wildcard: */15 starts at field minimum, not 0
Mutating datetime during iteration without advancing past current time
Code solutions
Solution 1 uses an object-oriented approach with separate CronField and CronExpression classes, using naive minute-by-minute iteration for next_run. Solution 2 takes a functional approach with smarter field-by-field jumping to avoid iterating through hundreds of thousands of minutes. Solution 3 focuses on validation with comprehensive error messages, regex-based input checking, and immutable frozen dataclasses. These vary in their approach to finding the next run time (brute force vs smart jumping) and how much validation/error reporting they provide.
Core techniques: field parsing with ranges/steps/wildcards, datetime arithmetic, weekday conversion (Python Mon=0 to cron Sun=0).
Solution 1: Object-oriented with field classes
Object-oriented approach with CronField and CronExpression classes. Uses naive minute-by-minute iteration for next_run. Clean separation between parsing and matching logic.
"""Cron Parser - Solution 1: Object-oriented with field classes."""from dataclasses import dataclassfrom datetime import datetime, timedeltafrom typing import Setdef cron_weekday(dt: datetime) -> int: """Convert Python weekday (Mon=0) to cron weekday (Sun=0).""" return (dt.weekday() + 1) % 7@dataclassclass CronField: """Represents a single cron field with its valid values.""" values: Set[int] @classmethod def parse(cls, expr: str, min_val: int, max_val: int) -> "CronField": values: Set[int] = set() for part in expr.split(","): if "/" in part: base, step = part.split("/") step_val = int(step) start = min_val if base == "*" else int(base.split("-")[0]) end = max_val if base == "*" or "-" not in base else int(base.split("-")[1]) values.update(range(start, end + 1, step_val)) elif "-" in part: start, end = map(int, part.split("-")) values.update(range(start, end + 1)) elif part == "*": values.update(range(min_val, max_val + 1)) else: values.add(int(part)) if not all(min_val <= v <= max_val for v in values): raise ValueError(f"Values must be in range [{min_val}, {max_val}]") return cls(values)@dataclassclass CronExpression: """Parsed cron expression with five fields.""" minute: CronField hour: CronField day: CronField month: CronField weekday: CronField @classmethod def parse(cls, expr: str) -> "CronExpression": parts = expr.strip().split() if len(parts) != 5: raise ValueError("Cron expression must have exactly 5 fields") return cls( CronField.parse(parts[0], 0, 59), CronField.parse(parts[1], 0, 23), CronField.parse(parts[2], 1, 31), CronField.parse(parts[3], 1, 12), CronField.parse(parts[4], 0, 6), ) def matches(self, dt: datetime) -> bool: return (dt.minute in self.minute.values and dt.hour in self.hour.values and dt.day in self.day.values and dt.month in self.month.values and cron_weekday(dt) in self.weekday.values) def next_run(self, from_time: datetime) -> datetime: dt = from_time.replace(second=0, microsecond=0) for _ in range(366 * 24 * 60): if self.matches(dt): return dt dt += timedelta(minutes=1) raise ValueError("No valid run time found within one year")
Solution 2: Functional approach with smart next_run
Functional approach with smarter next_run that jumps by field instead of minute-by-minute. Uses sorted lists for efficient “find next valid value” operations.
"""Cron Parser - Solution 2: Functional approach with smart next_run."""from dataclasses import dataclassfrom datetime import datetime, timedeltafrom typing import List, TupleFIELD_RANGES: List[Tuple[int, int]] = [(0, 59), (0, 23), (1, 31), (1, 12), (0, 6)]def cron_weekday(dt: datetime) -> int: """Convert Python weekday (Mon=0) to cron weekday (Sun=0).""" return (dt.weekday() + 1) % 7def parse_field(expr: str, min_val: int, max_val: int) -> List[int]: """Parse a single cron field into a sorted list of valid values.""" result = set() for part in expr.split(","): if "/" in part: base, step = part.split("/", 1) step_int = int(step) if base == "*": start, end = min_val, max_val elif "-" in base: start, end = map(int, base.split("-", 1)) else: start, end = int(base), max_val result.update(range(start, end + 1, step_int)) elif "-" in part: start, end = map(int, part.split("-", 1)) result.update(range(start, end + 1)) elif part == "*": result.update(range(min_val, max_val + 1)) else: result.add(int(part)) if any(v < min_val or v > max_val for v in result): raise ValueError(f"Values out of range [{min_val}, {max_val}]") return sorted(result)@dataclassclass Cron: """Cron expression with precomputed valid values for each field.""" minutes: List[int] hours: List[int] days: List[int] months: List[int] weekdays: List[int] @classmethod def parse(cls, expr: str) -> "Cron": parts = expr.strip().split() if len(parts) != 5: raise ValueError(f"Expected 5 fields, got {len(parts)}") fields = [parse_field(p, r[0], r[1]) for p, r in zip(parts, FIELD_RANGES)] return cls(*fields) def _find_next(self, vals: List[int], current: int) -> Tuple[int, bool]: """Find next valid value >= current. Returns (value, wrapped).""" for v in vals: if v >= current: return v, False return vals[0], True def next_run(self, from_time: datetime) -> datetime: dt = from_time.replace(second=0, microsecond=0) + timedelta(minutes=1) for _ in range(366 * 24 * 60): if dt.month not in self.months: next_month, wrapped = self._find_next(self.months, dt.month) dt = datetime(dt.year + (1 if wrapped else 0), next_month, 1, 0, 0) continue if dt.day not in self.days or cron_weekday(dt) not in self.weekdays: dt = (dt + timedelta(days=1)).replace(hour=0, minute=0) continue if dt.hour not in self.hours: next_hour, wrapped = self._find_next(self.hours, dt.hour) if wrapped: dt = (dt + timedelta(days=1)).replace(hour=0, minute=0) else: dt = dt.replace(hour=next_hour, minute=0) continue if dt.minute not in self.minutes: next_min, wrapped = self._find_next(self.minutes, dt.minute) dt = dt.replace(minute=0, hour=dt.hour + 1) if wrapped else dt.replace(minute=next_min) continue return dt raise ValueError("No valid time found")
Solution 3: Validation-focused with comprehensive error handling
Validation-focused approach with comprehensive error messages, regex-based input validation, and immutable frozen dataclass. Includes next_n_runs() utility method.
"""Cron Parser - Solution 3: Validation-focused with comprehensive error handling."""from dataclasses import dataclass, fieldfrom datetime import datetime, timedeltafrom typing import FrozenSet, Listimport reFIELD_SPECS = [("minute", 0, 59), ("hour", 0, 23), ("day", 1, 31), ("month", 1, 12), ("weekday", 0, 6)]def cron_weekday(dt: datetime) -> int: """Convert Python weekday (Mon=0) to cron weekday (Sun=0).""" return (dt.weekday() + 1) % 7def validate_and_parse(expr: str, name: str, lo: int, hi: int) -> FrozenSet[int]: """Parse field with detailed validation errors.""" if not re.match(r'^[\d,\-\*/]+$', expr): raise ValueError(f"Invalid characters in {name}: {expr}") values = set() for token in expr.split(","): step = 1 if "/" in token: token, step_str = token.rsplit("/", 1) if not step_str.isdigit() or int(step_str) == 0: raise ValueError(f"Invalid step in {name}: {step_str}") step = int(step_str) if token == "*": values.update(range(lo, hi + 1, step)) elif "-" in token: parts = token.split("-") if len(parts) != 2 or not all(p.isdigit() for p in parts): raise ValueError(f"Invalid range in {name}: {token}") start, end = int(parts[0]), int(parts[1]) if start > end: raise ValueError(f"Range start > end in {name}: {start}-{end}") values.update(range(start, end + 1, step)) else: if not token.isdigit(): raise ValueError(f"Invalid value in {name}: {token}") values.add(int(token)) out_of_range = [v for v in values if v < lo or v > hi] if out_of_range: raise ValueError(f"Out of range in {name}: {out_of_range} (valid: {lo}-{hi})") return frozenset(values)@dataclass(frozen=True)class CronSchedule: """Immutable cron schedule with validation.""" minute: FrozenSet[int] = field(default_factory=frozenset) hour: FrozenSet[int] = field(default_factory=frozenset) day: FrozenSet[int] = field(default_factory=frozenset) month: FrozenSet[int] = field(default_factory=frozenset) weekday: FrozenSet[int] = field(default_factory=frozenset) raw: str = "" @classmethod def from_string(cls, expr: str) -> "CronSchedule": parts = expr.strip().split() if len(parts) != 5: raise ValueError(f"Need 5 fields, got {len(parts)}: '{expr}'") parsed = {name: validate_and_parse(parts[i], name, lo, hi) for i, (name, lo, hi) in enumerate(FIELD_SPECS)} return cls(**parsed, raw=expr) def is_valid_time(self, dt: datetime) -> bool: return (dt.minute in self.minute and dt.hour in self.hour and dt.day in self.day and dt.month in self.month and cron_weekday(dt) in self.weekday) def next_run_after(self, start: datetime) -> datetime: current = start.replace(second=0, microsecond=0) + timedelta(minutes=1) for _ in range(366 * 24 * 60): if self.is_valid_time(current): return current current += timedelta(minutes=1) raise RuntimeError(f"No match within year for: {self.raw}") def next_n_runs(self, start: datetime, n: int) -> List[datetime]: runs, current = [], start for _ in range(n): current = self.next_run_after(current) runs.append(current) return runs
Question 40 - Library System
Difficulty: 1 / 10
Approximate lines of code: 100 LoC
Tags: storage
Description
A library management system tracks books, users, and loans. Core entities are books (with unique IDs, titles, and status), users, and transactions. The key data structures are dictionaries mapping book_id to Book objects, with each Book containing its current status (available/borrowed/reserved), the current borrower, due date, and a waitlist queue.
The interesting complexity comes from the waitlist: when a borrowed book is returned, it should become reserved for the first person on the waitlist rather than generally available. Late fees require tracking due dates and computing days overdue.
Part A: Checkout and Return
Problem: Part A
Implement add_book(), checkout(book_id, user_id), and return_book(book_id). A book can only be checked out if it’s available. Return should update status and clear the borrower.
lib = Library()
lib.add_book("B001", "The Great Gatsby")
# Internal state after add:
# books = {"B001": Book(book_id="B001", title="The Great Gatsby",
#                       status=AVAILABLE, borrower_id=None,
#                       due_date=None, waitlist=[])}

lib.checkout("B001", "U001")  # Returns True
# State: status=BORROWED, borrower_id="U001", due_date=<14 days from now>
lib.checkout("B001", "U002")  # Returns False - already borrowed

lib.return_book("B001")
# State: status=AVAILABLE, borrower_id=None, due_date=None
Part B: Reservations and Waitlist
Problem: Part B
Add reserve(book_id, user_id) that adds users to a waitlist. When a book is returned, it should be reserved for the first waitlisted user rather than becoming generally available.
lib.checkout("B001", "U001")lib.reserve("B001", "U002")lib.reserve("B001", "U003")# State: waitlist=["U002", "U003"]lib.reserve("B001", "U001") # Returns False - can't reserve what you havelib.reserve("B001", "U002") # Returns False - already on waitlistlib.return_book("B001")# State: status=RESERVED, waitlist=["U003"]# Book is now reserved for U002, not available to general publiclib.checkout("B001", "U003") # Returns False - U002 is first in queuelib.checkout("B001", "U002") # Returns True - reserved user can checkout# State: status=BORROWED, borrower_id="U002", waitlist=[]
Part C: Late Fees
Problem: Part C
Track due dates and calculate late fees on return. The fee is typically a daily rate times days overdue.
lib = Library(loan_days=14, daily_late_fee=0.25)
lib.add_book("B001", "Clean Code")

# Checkout on Jan 1, due Jan 15
lib.checkout("B001", "U001", now=datetime(2024, 1, 1))

# Return on Jan 21 (6 days late)
fee = lib.return_book("B001", now=datetime(2024, 1, 21))
assert fee == 6 * 0.25  # $1.50

# Return on time returns 0
lib.checkout("B001", "U002", now=datetime(2024, 1, 22))
fee = lib.return_book("B001", now=datetime(2024, 2, 1))  # 10 days, within 14
assert fee == 0.0
Interview comments
Edge cases to probe:
What happens if a user tries to reserve a book they already have checked out?
What if someone on the waitlist never picks up their reserved book?
How do you handle a user trying to checkout a book they’re first on the waitlist for?
What happens if return_book is called on a book that’s not borrowed?
Common mistakes:
Allowing a user to reserve a book they already have
Not removing user from waitlist when they checkout
Forgetting to change status to RESERVED (not AVAILABLE) when returning a waitlisted book
Using floating point for money (should use Decimal or cents; see the sketch after this list)
Not handling the case where checkout is called on a reserved book by the wrong user
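For the money point above, a minimal sketch of the Decimal alternative (the rate and day count are illustrative):
from decimal import Decimal

DAILY_LATE_FEE = Decimal("0.25")
days_overdue = 6
print(DAILY_LATE_FEE * days_overdue)   # 1.50, exactly - no float rounding drift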
Code solutions
Solution 1 is a straightforward OOP approach with Book and Library dataclasses, using an enum for book status and dictionaries for storage. Solution 2 adds an event-driven design with transaction history logging, enabling user history queries and auditing, plus injectable timestamps for testing. Solution 3 takes a functional approach with immutable (frozen) dataclasses and pure functions that return new state, enabling easy testing and undo/redo. The key differences are mutability (mutable objects vs immutable state) and whether operations are logged. Core techniques: state machines for book status, queue management for waitlists, datetime arithmetic for late fees.
Solution 1: Simple OOP
Simple OOP approach with Book and Library dataclasses. Uses an enum for book status. Stores everything in dictionaries. Straightforward implementation of all rules.
"""Library System - Solution 1: Simple OOP with dictionaries for storage."""from dataclasses import dataclass, fieldfrom datetime import datetime, timedeltafrom enum import Enumfrom typing import Optionalclass BookStatus(Enum): AVAILABLE = "available" BORROWED = "borrowed" RESERVED = "reserved"@dataclassclass Book: book_id: str title: str status: BookStatus = BookStatus.AVAILABLE borrower_id: Optional[str] = None due_date: Optional[datetime] = None waitlist: list[str] = field(default_factory=list)@dataclassclass Library: books: dict[str, Book] = field(default_factory=dict) loan_days: int = 14 daily_late_fee: float = 0.25 def add_book(self, book_id: str, title: str) -> Book: book = Book(book_id=book_id, title=title) self.books[book_id] = book return book def checkout(self, book_id: str, user_id: str) -> bool: book = self.books.get(book_id) if not book or book.status == BookStatus.BORROWED: return False if book.status == BookStatus.RESERVED and book.waitlist[0] != user_id: return False book.status = BookStatus.BORROWED book.borrower_id = user_id book.due_date = datetime.now() + timedelta(days=self.loan_days) if user_id in book.waitlist: book.waitlist.remove(user_id) return True def return_book(self, book_id: str) -> float: book = self.books.get(book_id) if not book or book.status != BookStatus.BORROWED: return 0.0 late_fee = self.calculate_late_fee(book) book.borrower_id = None book.due_date = None if book.waitlist: book.status = BookStatus.RESERVED else: book.status = BookStatus.AVAILABLE return late_fee def reserve(self, book_id: str, user_id: str) -> bool: book = self.books.get(book_id) if not book or user_id in book.waitlist or book.borrower_id == user_id: return False book.waitlist.append(user_id) if book.status == BookStatus.AVAILABLE: book.status = BookStatus.RESERVED return True def calculate_late_fee(self, book: Book) -> float: if not book.due_date or datetime.now() <= book.due_date: return 0.0 days_late = (datetime.now() - book.due_date).days return days_late * self.daily_late_feeif __name__ == "__main__": lib = Library() lib.add_book("B001", "The Great Gatsby") lib.add_book("B002", "1984") assert lib.checkout("B001", "U001") == True assert lib.books["B001"].status == BookStatus.BORROWED assert lib.checkout("B001", "U002") == False # Already borrowed assert lib.reserve("B001", "U002") == True assert lib.books["B001"].waitlist == ["U002"] lib.return_book("B001") assert lib.books["B001"].status == BookStatus.RESERVED assert lib.checkout("B001", "U002") == True # Reserved user can checkout print("Solution 1: All tests passed!")
Solution 2: Event-driven with Transaction History
Event-driven approach with transaction history. Logs all operations as Transaction objects with timestamps. Enables querying user history and auditing. Accepts optional now parameter for testing.
"""Library System - Solution 2: Event-driven with transaction history."""from dataclasses import dataclass, fieldfrom datetime import datetime, timedeltafrom enum import Enumfrom typing import Optionalclass EventType(Enum): CHECKOUT = "checkout" RETURN = "return" RESERVE = "reserve" CANCEL_RESERVATION = "cancel_reservation"@dataclassclass Transaction: event_type: EventType book_id: str user_id: str timestamp: datetime late_fee: float = 0.0@dataclassclass Book: book_id: str title: str borrower_id: Optional[str] = None due_date: Optional[datetime] = None waitlist: list[str] = field(default_factory=list) @property def is_available(self) -> bool: return self.borrower_id is None and not self.waitlist @property def is_borrowed(self) -> bool: return self.borrower_id is not None@dataclassclass LibrarySystem: books: dict[str, Book] = field(default_factory=dict) transactions: list[Transaction] = field(default_factory=list) loan_days: int = 14 daily_fee: float = 0.25 def add_book(self, book_id: str, title: str) -> None: self.books[book_id] = Book(book_id=book_id, title=title) def checkout(self, book_id: str, user_id: str, now: Optional[datetime] = None) -> bool: now = now or datetime.now() book = self.books.get(book_id) if not book or book.is_borrowed: return False if book.waitlist and book.waitlist[0] != user_id: return False book.borrower_id = user_id book.due_date = now + timedelta(days=self.loan_days) if user_id in book.waitlist: book.waitlist.remove(user_id) self._log(EventType.CHECKOUT, book_id, user_id, now) return True def return_book(self, book_id: str, now: Optional[datetime] = None) -> float: now = now or datetime.now() book = self.books.get(book_id) if not book or not book.is_borrowed: return 0.0 user_id = book.borrower_id late_fee = max(0, (now - book.due_date).days) * self.daily_fee if book.due_date else 0.0 book.borrower_id = None book.due_date = None self._log(EventType.RETURN, book_id, user_id, now, late_fee) return late_fee def reserve(self, book_id: str, user_id: str, now: Optional[datetime] = None) -> bool: now = now or datetime.now() book = self.books.get(book_id) if not book or user_id in book.waitlist or book.borrower_id == user_id: return False book.waitlist.append(user_id) self._log(EventType.RESERVE, book_id, user_id, now) return True def get_user_history(self, user_id: str) -> list[Transaction]: return [t for t in self.transactions if t.user_id == user_id] def _log(self, event: EventType, book_id: str, user_id: str, ts: datetime, fee: float = 0.0) -> None: self.transactions.append(Transaction(event, book_id, user_id, ts, fee))if __name__ == "__main__": lib = LibrarySystem() lib.add_book("B001", "Clean Code") now = datetime(2024, 1, 1) assert lib.checkout("B001", "U001", now) == True assert lib.books["B001"].is_borrowed == True lib.reserve("B001", "U002", now) assert lib.books["B001"].waitlist == ["U002"] late_return = now + timedelta(days=20) fee = lib.return_book("B001", late_return) assert fee == 6 * 0.25 # 6 days late assert lib.checkout("B001", "U002", late_return) == True history = lib.get_user_history("U001") assert len(history) == 2 # checkout + return print("Solution 2: All tests passed!")
Solution 3: Functional Approach with Immutable State
Functional approach with immutable state. Uses frozen dataclasses and pure functions that return new state. Enables easy testing and undo/redo functionality. No side effects.
"""Library System - Solution 3: Functional approach with immutable state."""from dataclasses import dataclass, replacefrom datetime import datetime, timedeltafrom typing import Optional@dataclass(frozen=True)class BookState: book_id: str title: str borrower_id: Optional[str] = None due_date: Optional[datetime] = None waitlist: tuple[str, ...] = ()@dataclass(frozen=True)class LibraryState: books: tuple[BookState, ...] = () loan_days: int = 14 daily_fee: float = 0.25def add_book(state: LibraryState, book_id: str, title: str) -> LibraryState: new_book = BookState(book_id=book_id, title=title) return replace(state, books=state.books + (new_book,))def find_book(state: LibraryState, book_id: str) -> Optional[BookState]: return next((b for b in state.books if b.book_id == book_id), None)def update_book(state: LibraryState, updated: BookState) -> LibraryState: books = tuple(updated if b.book_id == updated.book_id else b for b in state.books) return replace(state, books=books)def checkout(state: LibraryState, book_id: str, user_id: str, now: datetime) -> tuple[LibraryState, bool]: book = find_book(state, book_id) if not book or book.borrower_id is not None: return state, False if book.waitlist and book.waitlist[0] != user_id: return state, False waitlist = tuple(u for u in book.waitlist if u != user_id) updated = replace(book, borrower_id=user_id, due_date=now + timedelta(days=state.loan_days), waitlist=waitlist) return update_book(state, updated), Truedef return_book(state: LibraryState, book_id: str, now: datetime) -> tuple[LibraryState, float]: book = find_book(state, book_id) if not book or book.borrower_id is None: return state, 0.0 fee = 0.0 if book.due_date and now > book.due_date: fee = (now - book.due_date).days * state.daily_fee updated = replace(book, borrower_id=None, due_date=None) return update_book(state, updated), feedef reserve(state: LibraryState, book_id: str, user_id: str) -> tuple[LibraryState, bool]: book = find_book(state, book_id) if not book or user_id in book.waitlist or book.borrower_id == user_id: return state, False updated = replace(book, waitlist=book.waitlist + (user_id,)) return update_book(state, updated), Truedef get_available_books(state: LibraryState) -> list[BookState]: return [b for b in state.books if b.borrower_id is None and not b.waitlist]if __name__ == "__main__": now = datetime(2024, 1, 1) state = LibraryState() state = add_book(state, "B001", "Design Patterns") state = add_book(state, "B002", "Refactoring") state, ok = checkout(state, "B001", "U001", now) assert ok == True assert find_book(state, "B001").borrower_id == "U001" state, ok = reserve(state, "B001", "U002") assert ok == True assert find_book(state, "B001").waitlist == ("U002",) late = now + timedelta(days=21) state, fee = return_book(state, "B001", late) assert fee == 7 * 0.25 state, ok = checkout(state, "B001", "U002", late) assert ok == True available = get_available_books(state) assert len(available) == 1 assert available[0].book_id == "B002" print("Solution 3: All tests passed!")
Question 41 - Write-Ahead Log
Difficulty: 10 / 10
Approximate lines of code: 100 LoC
Tags: storage, distributed-systems
Description
A Write-Ahead Log (WAL) provides durability for databases and storage systems. The core principle: before making any change to data, write a description of the change to a sequential log file and ensure it’s on disk (fsync). If the system crashes, replay the log to recover. This is faster than writing data directly because sequential writes are cheaper than random writes.
Key components: (1) log entries with sequence numbers for ordering, (2) fsync calls to guarantee durability (not just flush), (3) checksums to detect partial/corrupted writes, and (4) checkpointing to truncate old entries after they’re applied to the main data store.
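The ordering is the whole point: log the change first, apply it second. Here is a minimal sketch of what that looks like from the caller's side (the KVStore wrapper and its method names are hypothetical, assuming a WAL that exposes append(op, key, value) and replay(fn) and entries with op/key/value fields):

class KVStore:
    """Toy key-value store whose every mutation is logged before it is applied."""

    def __init__(self, wal):
        self.wal = wal      # any WAL with append(op, key, value) and replay(fn)
        self.data: dict = {}

    def set(self, key: str, value) -> None:
        self.wal.append("SET", key, value)   # durable on disk first...
        self.data[key] = value               # ...then the in-memory change

    def recover(self) -> None:
        """After a crash, rebuild in-memory state by replaying the log."""
        self.data.clear()

        def apply(entry) -> None:
            if entry.op == "SET":            # field names assumed for this sketch
                self.data[entry.key] = entry.value

        self.wal.replay(apply)

Because the record reaches disk before the store is touched, an acknowledged write can never be lost: either the crash happens before the append (nothing acknowledged) or the replay restores it.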
Part A: Append and Replay
Problem: Part A
Implement append(entry) that writes to the log with proper durability and replay() that reads back all entries. Each entry needs a sequence number for ordering.
Critical: flush() is NOT enough for durability. You must call os.fsync(fd) to guarantee data reaches disk.
def append(self, op, key, value):
    entry = LogEntry(self._sequence, op, key, value)
    self._sequence += 1
    self._file.write(json.dumps(vars(entry)) + "\n")
    self._file.flush()               # Push to OS buffer
    os.fsync(self._file.fileno())    # Force to disk - THIS IS CRITICAL
    return entry
Part B: Crash Recovery
Problem: Part B
Handle partial writes from crashes. If the system crashes mid-write, the last entry may be corrupted or incomplete. Use checksums or length-prefixing to detect and skip bad entries during replay.
# Log file after crash mid-write:
# {"sequence_num": 0, "op": "SET", "key": "x", "value": 10, "crc": 12345}
# {"sequence_num": 1, "op": "SET", "key": "y", "value": 20, "crc": 67890}
# {"sequence_num": 2, "op": "SE   <- CORRUPTED (partial write)

def replay(self, apply_fn):
    for line in self._file:
        try:
            entry = parse_entry(line)
            if entry.crc != compute_crc(entry):
                break  # Stop at first corrupted entry
            apply_fn(entry)
        except (JSONDecodeError, ValueError):
            break  # Stop at first corrupted entry
Operations must be idempotent - replaying the same entry twice should have the same effect as once:
# Idempotent: SET x = 10 (replay twice, x is still 10)
# NOT idempotent: INCREMENT x (replay twice, x increases twice!)
# For non-idempotent ops, track which LSNs have been applied
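One way to make replay safe for non-idempotent operations is to persist the last applied LSN alongside the data store and skip anything at or below it. A rough sketch (the applied_lsn checkpoint is an assumption, not part of the solutions below):

def replay_with_lsn_tracking(entries, store: dict, applied_lsn: int) -> int:
    """Apply only entries newer than the last LSN the store has already absorbed."""
    for entry in entries:
        if entry["lsn"] <= applied_lsn:
            continue  # already applied before the crash - skip, don't double-apply
        if entry["op"] == "INCREMENT":
            store[entry["key"]] = store.get(entry["key"], 0) + entry["value"]
        applied_lsn = entry["lsn"]
    return applied_lsn  # persist this together with the store at checkpoint time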
Part C: Checkpointing
Problem: Part C
After entries are applied to the main data store, they’re no longer needed for recovery. Implement checkpoint() that marks a point where all previous entries can be safely truncated.
wal = WriteAheadLog("log.wal")
wal.append("SET", "x", 10)  # LSN 0
wal.append("SET", "y", 20)  # LSN 1
wal.append("SET", "z", 30)  # LSN 2

# Apply entries 0-2 to main data store...
# Now entries 0-2 are durable in main store
wal.checkpoint()            # Truncates log, entries 0-2 removed

wal.append("SET", "w", 40)  # LSN 0 (sequence resets)
# On recovery, only replay from LSN 0 (the new entry)
Critical: don’t truncate BEFORE checkpointing! If you truncate and then crash before the main data store is synced, you lose data.
def checkpoint(self):
    # Main data store MUST be durable before calling this
    self._file.close()
    self._file = open(self.log_path, "w+")  # Truncate
    self._sequence = 0
Interview comments
Edge cases to probe:
What’s the difference between flush() and fsync()? (flush pushes to OS, fsync guarantees disk write)
What if power fails mid-fsync? (Disks typically guarantee atomicity only at the sector level, so the last entry may still be partial)
When is it safe to truncate the log? (Only after main data store is synced to disk)
How would you handle concurrent appends? (Single writer, or use file locks)
Common mistakes:
Using flush() without fsync() (data still in OS buffer, not durable)
No sequence numbers (replay order becomes ambiguous)
Blocking on every write when batching could improve throughput
Truncating before checkpoint (crash loses data)
Ignoring partial writes (last entry may be corrupted)
No checksums (silent corruption undetected)
Code solutions
Solution 1 is a simple JSON-based WAL using a single file with fsync on every write for durability. Solution 2 uses multiple segment files with automatic rotation when size limits are reached, supporting partial replay from a specific LSN. Solution 3 uses a binary format with struct packing and CRC32 checksums, plus transaction markers (BEGIN/COMMIT/ROLLBACK). These differ in storage format (JSON vs binary), file organization (single vs segmented), and corruption detection (parse errors vs checksums). Core techniques: fsync for durability, sequential file I/O, checksums, log compaction.
Solution 1: Simple JSON-based WAL
Simple JSON-based WAL with file-based storage. Entries are JSON lines with sequence numbers. Includes fsync on every write for durability. Basic checkpoint truncates the entire log.
"""Write-Ahead Log - Solution 1: Simple File-Based WALBasic implementation using a single log file with JSON entries."""from dataclasses import dataclass, fieldfrom typing import Any, Callableimport jsonimport os@dataclassclass LogEntry: """A single entry in the write-ahead log.""" sequence_num: int operation: str key: str value: Any@dataclassclass WriteAheadLog: """Simple file-based write-ahead log implementation.""" log_path: str sync_on_write: bool = True _sequence: int = field(default=0, init=False) _file: Any = field(default=None, init=False) def __post_init__(self) -> None: self._file = open(self.log_path, "a+") self._sequence = self._count_existing_entries() def _count_existing_entries(self) -> int: """Count entries to determine next sequence number.""" self._file.seek(0) count = sum(1 for line in self._file if line.strip()) self._file.seek(0, 2) # Seek to end return count def append(self, operation: str, key: str, value: Any) -> LogEntry: """Append an entry to the log before applying changes.""" entry = LogEntry(self._sequence, operation, key, value) self._sequence += 1 line = json.dumps(vars(entry)) + "\n" self._file.write(line) if self.sync_on_write: self._file.flush() os.fsync(self._file.fileno()) return entry def replay(self, apply_fn: Callable[[LogEntry], None]) -> int: """Replay all log entries for recovery.""" self._file.seek(0) count = 0 for line in self._file: if line.strip(): data = json.loads(line) entry = LogEntry(**data) apply_fn(entry) count += 1 self._file.seek(0, 2) return count def checkpoint(self) -> None: """Truncate log after successful checkpoint.""" self._file.close() self._file = open(self.log_path, "w+") self._sequence = 0 def close(self) -> None: """Close the log file.""" if self._file: self._file.close()
Solution 2: Segmented WAL with compaction
Segmented WAL with multiple files. When a segment reaches max size, rotates to a new file. Supports partial replay from a specific LSN. Compaction removes segments fully before a checkpoint LSN. More realistic for production systems that need to bound log size.
"""Write-Ahead Log - Solution 2: Segment-Based WALUses multiple segment files with compaction support."""from dataclasses import dataclass, fieldfrom typing import Any, Callable, Iteratorfrom pathlib import Pathimport jsonimport os@dataclassclass LogEntry: """A single WAL entry with LSN (Log Sequence Number).""" lsn: int op: str key: str value: Any checksum: int = 0 def __post_init__(self) -> None: if self.checksum == 0: self.checksum = hash((self.lsn, self.op, self.key, str(self.value)))@dataclassclass SegmentedWAL: """Segmented WAL with configurable segment size and compaction.""" directory: Path max_segment_size: int = 1024 * 1024 # 1MB default _current_lsn: int = field(default=0, init=False) _current_segment: int = field(default=0, init=False) _current_file: Any = field(default=None, init=False) _segment_size: int = field(default=0, init=False) def __post_init__(self) -> None: self.directory = Path(self.directory) self.directory.mkdir(parents=True, exist_ok=True) self._recover_state() def _segment_path(self, seg_num: int) -> Path: return self.directory / f"segment_{seg_num:08d}.wal" def _recover_state(self) -> None: """Recover LSN and segment info from existing files.""" segments = sorted(self.directory.glob("segment_*.wal")) if segments: self._current_segment = int(segments[-1].stem.split("_")[1]) for entry in self._read_segment(segments[-1]): self._current_lsn = entry.lsn + 1 self._open_segment() def _open_segment(self) -> None: if self._current_file: self._current_file.close() path = self._segment_path(self._current_segment) self._current_file = open(path, "a+") self._segment_size = path.stat().st_size if path.exists() else 0 def _read_segment(self, path: Path) -> Iterator[LogEntry]: with open(path, "r") as f: for line in f: if line.strip(): yield LogEntry(**json.loads(line)) def append(self, op: str, key: str, value: Any, fsync: bool = True) -> LogEntry: """Append entry with optional fsync for durability.""" entry = LogEntry(self._current_lsn, op, key, value) self._current_lsn += 1 data = json.dumps(vars(entry)) + "\n" if self._segment_size + len(data) > self.max_segment_size: self._rotate_segment() self._current_file.write(data) self._segment_size += len(data) if fsync: self._current_file.flush() os.fsync(self._current_file.fileno()) return entry def _rotate_segment(self) -> None: self._current_segment += 1 self._open_segment() def replay(self, apply_fn: Callable[[LogEntry], None], from_lsn: int = 0) -> int: """Replay entries starting from given LSN.""" count = 0 for seg_path in sorted(self.directory.glob("segment_*.wal")): for entry in self._read_segment(seg_path): if entry.lsn >= from_lsn: apply_fn(entry) count += 1 return count def compact(self, checkpoint_lsn: int) -> None: """Remove segments fully before checkpoint LSN.""" self._current_file.close() for seg_path in sorted(self.directory.glob("segment_*.wal")): max_lsn = max((e.lsn for e in self._read_segment(seg_path)), default=-1) if max_lsn < checkpoint_lsn: seg_path.unlink() self._open_segment() def close(self) -> None: if self._current_file: self._current_file.close()
Solution 3: Binary format with CRC checksums
Binary format WAL with CRC32 checksums for corruption detection. Uses struct packing for efficient storage. Supports transaction markers (BEGIN/COMMIT/ROLLBACK). Low-level file descriptor operations for precise control over durability.
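A minimal sketch of the record encoding this describes, length-prefixed struct packing with a CRC32 checksum over the body; the exact field layout here is an assumption:

import struct
import zlib

def encode_entry(lsn: int, op: str, key: str, value: str) -> bytes:
    """Pack one record: [body_len][crc32] then [lsn][op_len][op][key_len][key][val_len][val]."""
    body = struct.pack(">Q", lsn)
    for part in (op, key, value):
        raw = part.encode("utf-8")
        body += struct.pack(">I", len(raw)) + raw
    crc = zlib.crc32(body) & 0xFFFFFFFF
    return struct.pack(">II", len(body), crc) + body

def decode_entry(buf: bytes) -> tuple[int, str, str, str] | None:
    """Return (lsn, op, key, value), or None if the record is truncated or corrupted."""
    if len(buf) < 8:
        return None
    length, crc = struct.unpack(">II", buf[:8])
    body = buf[8:8 + length]
    if len(body) < length or zlib.crc32(body) & 0xFFFFFFFF != crc:
        return None  # partial write or bit rot - stop replay here
    lsn = struct.unpack(">Q", body[:8])[0]
    fields, offset = [], 8
    for _ in range(3):
        (flen,) = struct.unpack(">I", body[offset:offset + 4])
        offset += 4
        fields.append(body[offset:offset + flen].decode("utf-8"))
        offset += flen
    return (lsn, fields[0], fields[1], fields[2])

assert decode_entry(encode_entry(5, "SET", "x", "10")) == (5, "SET", "x", "10")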
Question 42 - Snowflake ID Generator
Description
A Snowflake-style ID generator produces globally unique 64-bit IDs without coordination between machines. Twitter invented this pattern to generate tweet IDs at scale. The key insight is packing three components into 64 bits: a timestamp (for ordering), a machine ID (for uniqueness across nodes), and a sequence number (for uniqueness within the same millisecond). The standard layout uses 41 bits for timestamp, 10 bits for machine ID (1024 machines), and 12 bits for sequence (4096 IDs per millisecond per machine).
Internal state consists of: the last timestamp seen, the current sequence number within that millisecond, and the machine ID. When generating an ID, if the current time equals the last timestamp, increment the sequence; if sequence overflows, wait for the next millisecond. If time has moved forward, reset sequence to 0.
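The packing itself is just shifts and masks over those three components. A quick sketch of the 41/10/12 layout described above (the custom epoch value is an assumption):

EPOCH_MS = 1704067200000  # custom epoch, e.g. 2024-01-01 UTC (an assumption)

def pack_id(timestamp_ms: int, machine_id: int, sequence: int) -> int:
    """41 bits timestamp | 10 bits machine | 12 bits sequence."""
    return ((timestamp_ms - EPOCH_MS) << 22) | (machine_id << 12) | sequence

def unpack_id(snowflake: int) -> tuple[int, int, int]:
    sequence = snowflake & 0xFFF             # low 12 bits
    machine_id = (snowflake >> 12) & 0x3FF   # next 10 bits
    timestamp_ms = (snowflake >> 22) + EPOCH_MS
    return timestamp_ms, machine_id, sequence

assert unpack_id(pack_id(EPOCH_MS + 123, 7, 42)) == (EPOCH_MS + 123, 7, 42)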
Part A: Basic ID Generation
Problem: Part A
Implement a thread-safe ID generator with generate() -> int. Each ID must be unique and roughly time-ordered. Pack timestamp, machine_id, and sequence into a 64-bit integer using bit shifting.
Part B: Sequence Overflow and Clock Skew
Problem: Part B
Sequence overflow: If 4096 IDs are generated in one millisecond, wait for the next millisecond before continuing.
Clock skew: If the system clock moves backwards (NTP correction, VM migration), either raise an error or wait for time to catch up.
gen = SnowflakeIDGenerator(machine_id=1)

# Simulate generating 4096 IDs in same millisecond
for i in range(4096):
    gen.generate()
# State: sequence=4095

# Next generate() must wait for next millisecond
id_next = gen.generate()  # Blocks until timestamp advances
# State: last_timestamp=1704067200001, sequence=0

# Clock moves backwards
# Option 1: Raise RuntimeError("Clock moved backwards by Xms")
# Option 2: Wait until time >= last_timestamp
Part C: ID Decoding and Coordination
Problem: Part C
Add parse_id(id) -> {timestamp_ms, machine_id, sequence} to decode components from an ID. Discuss how machines would coordinate to get unique machine IDs (ZooKeeper, database sequence, configuration).
gen = SnowflakeIDGenerator(machine_id=7)
id1 = gen.generate()
parsed = gen.parse_id(id1)
# Returns: {"timestamp_ms": 1704067200000, "machine_id": 7, "sequence": 0}

# Decoding uses bit masks and shifts:
sequence = id1 & 0xFFF             # Low 12 bits
machine_id = (id1 >> 12) & 0x3FF   # Next 10 bits
timestamp = (id1 >> 22) + EPOCH    # High 41 bits + epoch
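For the coordination discussion, one low-infrastructure option (an assumption here, not prescribed by the question) is to have the deployment system hand each host an explicit ID and fail loudly at startup if it is missing or out of range:

import os

def machine_id_from_env(bits: int = 10) -> int:
    """Read an explicitly configured machine ID, e.g. set by the deployment system."""
    machine_id = int(os.environ["MACHINE_ID"])  # hypothetical env var
    if not 0 <= machine_id < (1 << bits):
        raise ValueError(f"MACHINE_ID must be in [0, {1 << bits})")
    return machine_id

A registry with real uniqueness guarantees (a ZooKeeper ephemeral node or a database sequence, as the question suggests) is safer than ad-hoc schemes like hashing the hostname, which can silently collide.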
Interview comments
Edge cases to probe:
What happens if generate() is called from multiple threads simultaneously?
What if the machine ID is out of range (negative or > 1023)?
What if the system clock jumps forward by years (timestamp overflow)?
How would you handle a datacenter-level ID collision if machine IDs are misconfigured?
Common mistakes:
Not using a lock (race condition between reading timestamp and incrementing sequence)
Forgetting to wait when sequence overflows (generates duplicate IDs)
Using floating point for timestamps (precision loss)
Allowing negative machine IDs or IDs > max value
Not handling clock skew at all (silent ID collision)
Code solutions
Solution 1 provides a basic thread-safe implementation with dataclass configuration and RuntimeError on clock skew. Solution 2 adds clock skew tolerance by waiting instead of failing, plus ID parsing and injectable time functions for testing. Solution 3 introduces batch ID generation for efficiency and an immutable SnowflakeID dataclass with comparison operators. The key differences are in how they handle clock skew (fail vs wait) and whether they support batch operations. Core techniques: bit shifting/masking, thread synchronization with locks, busy-waiting for sequence overflow.
Solution 1: Basic Implementation
Basic implementation with all core functionality. Uses a dataclass for config, validates machine ID, and handles sequence overflow by busy-waiting. Raises RuntimeError on clock skew.
"""Snowflake-style ID Generator - Solution 1: Basic Implementation64-bit ID structure: - 1 bit: unused (sign bit) - 41 bits: timestamp (milliseconds since epoch) - 10 bits: machine ID (0-1023) - 12 bits: sequence number (0-4095)"""import timeimport threadingfrom dataclasses import dataclass@dataclassclass SnowflakeConfig: epoch: int = 1704067200000 # 2024-01-01 00:00:00 UTC in milliseconds machine_id_bits: int = 10 sequence_bits: int = 12 timestamp_bits: int = 41class SnowflakeIDGenerator: def __init__(self, machine_id: int, config: SnowflakeConfig | None = None): self.config = config or SnowflakeConfig() self._validate_machine_id(machine_id) self.machine_id = machine_id self.sequence = 0 self.last_timestamp = -1 self.lock = threading.Lock() # Precompute masks and shifts self.max_sequence = (1 << self.config.sequence_bits) - 1 self.machine_id_shift = self.config.sequence_bits self.timestamp_shift = self.config.sequence_bits + self.config.machine_id_bits def _validate_machine_id(self, machine_id: int) -> None: max_machine_id = (1 << self.config.machine_id_bits) - 1 if not 0 <= machine_id <= max_machine_id: raise ValueError(f"Machine ID must be between 0 and {max_machine_id}") def _current_timestamp(self) -> int: return int(time.time() * 1000) def _wait_for_next_millisecond(self, last_ts: int) -> int: ts = self._current_timestamp() while ts <= last_ts: ts = self._current_timestamp() return ts def generate(self) -> int: with self.lock: timestamp = self._current_timestamp() # Handle clock skew (clock moved backwards) if timestamp < self.last_timestamp: raise RuntimeError( f"Clock moved backwards by {self.last_timestamp - timestamp}ms" ) if timestamp == self.last_timestamp: self.sequence = (self.sequence + 1) & self.max_sequence if self.sequence == 0: # Sequence overflow, wait for next millisecond timestamp = self._wait_for_next_millisecond(self.last_timestamp) else: self.sequence = 0 self.last_timestamp = timestamp relative_ts = timestamp - self.config.epoch return ( (relative_ts << self.timestamp_shift) | (self.machine_id << self.machine_id_shift) | self.sequence )if __name__ == "__main__": gen = SnowflakeIDGenerator(machine_id=1) # Test uniqueness ids = [gen.generate() for _ in range(1000)] assert len(ids) == len(set(ids)), "IDs must be unique" # Test ordering assert ids == sorted(ids), "IDs must be ordered" # Test ID structure sample_id = ids[0] assert sample_id > 0, "ID must be positive" assert sample_id.bit_length() <= 64, "ID must fit in 64 bits" print(f"Generated {len(ids)} unique, ordered IDs") print(f"Sample ID: {sample_id} (binary length: {sample_id.bit_length()} bits)") print("All tests passed!")
Solution 2: With Clock Skew Tolerance
Adds clock skew tolerance by waiting instead of failing (up to a configurable limit). Includes a parse_id method to decode IDs back to components. Accepts an injectable time function for testing.
"""Snowflake-style ID Generator - Solution 2: With Clock Skew ToleranceHandles clock skew by waiting or borrowing from the future within tolerance."""import timeimport threadingfrom dataclasses import dataclass, fieldfrom typing import Callable@dataclassclass IDGeneratorConfig: epoch_ms: int = 1704067200000 # 2024-01-01 UTC machine_bits: int = 10 sequence_bits: int = 12 max_clock_skew_ms: int = 5000 # Tolerate up to 5 seconds of clock drift@dataclassclass GeneratorState: last_timestamp: int = -1 sequence: int = 0class RobustIDGenerator: def __init__( self, machine_id: int, config: IDGeneratorConfig | None = None, time_fn: Callable[[], int] | None = None, ): self.config = config or IDGeneratorConfig() self.machine_id = self._validate_and_get_machine_id(machine_id) self.state = GeneratorState() self.lock = threading.Lock() self._time_fn = time_fn or (lambda: int(time.time() * 1000)) self.max_sequence = (1 << self.config.sequence_bits) - 1 self.max_machine_id = (1 << self.config.machine_bits) - 1 self.machine_shift = self.config.sequence_bits self.timestamp_shift = self.config.sequence_bits + self.config.machine_bits def _validate_and_get_machine_id(self, machine_id: int) -> int: max_id = (1 << self.config.machine_bits) - 1 if not 0 <= machine_id <= max_id: raise ValueError(f"Machine ID must be 0-{max_id}, got {machine_id}") return machine_id def _handle_clock_skew(self, current_ts: int, last_ts: int) -> int: skew = last_ts - current_ts if skew > self.config.max_clock_skew_ms: raise RuntimeError(f"Clock skew {skew}ms exceeds tolerance") # Wait for clock to catch up time.sleep(skew / 1000.0) return self._time_fn() def generate(self) -> int: with self.lock: current_ts = self._time_fn() if current_ts < self.state.last_timestamp: current_ts = self._handle_clock_skew( current_ts, self.state.last_timestamp ) if current_ts == self.state.last_timestamp: self.state.sequence = (self.state.sequence + 1) & self.max_sequence if self.state.sequence == 0: while current_ts <= self.state.last_timestamp: current_ts = self._time_fn() else: self.state.sequence = 0 self.state.last_timestamp = current_ts relative_ts = current_ts - self.config.epoch_ms return ( (relative_ts << self.timestamp_shift) | (self.machine_id << self.machine_shift) | self.state.sequence ) def parse_id(self, id_value: int) -> dict: """Decode an ID back into its components.""" sequence = id_value & self.max_sequence machine = (id_value >> self.machine_shift) & self.max_machine_id timestamp = (id_value >> self.timestamp_shift) + self.config.epoch_ms return {"timestamp_ms": timestamp, "machine_id": machine, "sequence": sequence}if __name__ == "__main__": gen = RobustIDGenerator(machine_id=42) ids = [gen.generate() for _ in range(500)] assert len(ids) == len(set(ids)), "IDs must be unique" assert all(ids[i] < ids[i + 1] for i in range(len(ids) - 1)), "IDs must be ordered" parsed = gen.parse_id(ids[0]) assert parsed["machine_id"] == 42, "Machine ID must match" print(f"Parsed first ID: {parsed}") # Test with custom time function (simulated clock) fake_time = [1704067200000] gen2 = RobustIDGenerator(machine_id=1, time_fn=lambda: fake_time[0]) id1 = gen2.generate() id2 = gen2.generate() assert id1 < id2, "Sequence should increment" print(f"Generated {len(ids)} unique ordered IDs") print("All tests passed!")
Solution 3: Async-Ready with Batch Generation
Adds batch ID generation for efficiency (generate multiple IDs while holding the lock once). Includes an immutable SnowflakeID dataclass that supports comparison operations.
"""Snowflake-style ID Generator - Solution 3: Async-Ready with Batch GenerationFeatures batch ID generation and async-compatible design."""import timeimport threadingfrom dataclasses import dataclassfrom typing import Iterator@dataclass(frozen=True)class SnowflakeID: """Immutable ID with parsed components.""" value: int timestamp_ms: int machine_id: int sequence: int def __int__(self) -> int: return self.value def __lt__(self, other: "SnowflakeID") -> bool: return self.value < other.valueclass BatchIDGenerator: EPOCH = 1704067200000 # 2024-01-01 UTC MACHINE_BITS = 10 SEQUENCE_BITS = 12 MAX_SEQUENCE = (1 << SEQUENCE_BITS) - 1 MAX_MACHINE_ID = (1 << MACHINE_BITS) - 1 MACHINE_SHIFT = SEQUENCE_BITS TIMESTAMP_SHIFT = SEQUENCE_BITS + MACHINE_BITS def __init__(self, machine_id: int): if not 0 <= machine_id <= self.MAX_MACHINE_ID: raise ValueError(f"Machine ID must be 0-{self.MAX_MACHINE_ID}") self.machine_id = machine_id self._sequence = 0 self._last_ts = -1 self._lock = threading.Lock() def _get_time_ms(self) -> int: return int(time.time() * 1000) def _next_timestamp(self) -> int: ts = self._get_time_ms() while ts <= self._last_ts: ts = self._get_time_ms() return ts def _build_id(self, timestamp: int, sequence: int) -> int: relative_ts = timestamp - self.EPOCH return ( (relative_ts << self.TIMESTAMP_SHIFT) | (self.machine_id << self.MACHINE_SHIFT) | sequence ) def generate(self) -> int: with self._lock: ts = self._get_time_ms() if ts < self._last_ts: raise RuntimeError(f"Clock regression: {self._last_ts - ts}ms") if ts == self._last_ts: self._sequence = (self._sequence + 1) & self.MAX_SEQUENCE if self._sequence == 0: ts = self._next_timestamp() else: self._sequence = 0 self._last_ts = ts return self._build_id(ts, self._sequence) def generate_batch(self, count: int) -> list[int]: """Generate multiple IDs efficiently.""" if count <= 0: return [] with self._lock: ids = [] ts = self._get_time_ms() for _ in range(count): if ts < self._last_ts: raise RuntimeError("Clock regression detected") if ts == self._last_ts: self._sequence = (self._sequence + 1) & self.MAX_SEQUENCE if self._sequence == 0: ts = self._next_timestamp() else: self._sequence = 0 self._last_ts = ts ids.append(self._build_id(ts, self._sequence)) return ids def parse(self, id_value: int) -> SnowflakeID: """Parse an ID into its components.""" seq = id_value & self.MAX_SEQUENCE machine = (id_value >> self.MACHINE_SHIFT) & self.MAX_MACHINE_ID ts = (id_value >> self.TIMESTAMP_SHIFT) + self.EPOCH return SnowflakeID(value=id_value, timestamp_ms=ts, machine_id=machine, sequence=seq)if __name__ == "__main__": gen = BatchIDGenerator(machine_id=7) # Single generation test id1 = gen.generate() id2 = gen.generate() assert id1 < id2, "IDs must be monotonically increasing" # Batch generation test batch = gen.generate_batch(100) assert len(batch) == 100, "Batch size must match" assert len(set(batch)) == 100, "All IDs in batch must be unique" assert batch == sorted(batch), "Batch IDs must be ordered" assert all(b > id2 for b in batch), "Batch IDs must be greater than previous" # Parse test parsed = gen.parse(batch[50]) assert parsed.machine_id == 7, "Machine ID must match" assert parsed.sequence >= 0, "Sequence must be non-negative" print(f"Parsed ID: {parsed}") # SnowflakeID comparison sf1 = gen.parse(batch[0]) sf2 = gen.parse(batch[1]) assert sf1 < sf2, "SnowflakeID comparison must work" print(f"Generated {len(batch) + 2} unique ordered IDs") print("All tests passed!")
Question 43 - Sudoku Solver
Difficulty: 4 / 10
Approximate lines of code: 80 LoC
Tags: algorithms, game/simulation
Description
A Sudoku solver fills in a 9x9 grid such that each row, column, and 3x3 box contains the digits 1-9 exactly once. This is a classic constraint satisfaction problem solved with backtracking: try a value, check if it violates constraints, recurse or backtrack. The naive approach tries all values 1-9 for each empty cell, but optimizations like MRV (Minimum Remaining Values) and constraint propagation dramatically reduce the search space.
The core data structures are: (1) a 9x9 grid with 0 representing empty cells, and (2) constraint tracking (which values are used in each row, column, and box). The box index for cell (r, c) is computed as 3 * (r // 3) + c // 3.
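A quick sketch of the two building blocks this relies on, the box-index formula and a plain validity check (grid is the usual 9x9 list of lists with 0 for empty):

def box_index(r: int, c: int) -> int:
    """Which of the nine 3x3 boxes (0-8) cell (r, c) belongs to."""
    return 3 * (r // 3) + c // 3

def is_valid(grid: list[list[int]], r: int, c: int, num: int) -> bool:
    """Can num be placed at (r, c) without violating row/col/box constraints?"""
    if num in grid[r]:
        return False
    if any(grid[i][c] == num for i in range(9)):
        return False
    br, bc = 3 * (r // 3), 3 * (c // 3)
    return all(grid[i][j] != num
               for i in range(br, br + 3)
               for j in range(bc, bc + 3))

assert box_index(4, 7) == 5  # middle band, right-hand box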
Part A: Basic Backtracking Solver
Problem: Part A
Implement a solver using simple backtracking. Find an empty cell, try values 1-9, check validity, and recurse. Backtrack when no valid value exists.
Part B: MRV Heuristic
Problem: Part B
Instead of picking the first empty cell, pick the cell with the fewest valid candidates. This prunes the search tree earlier - if a cell has only one valid option, fill it immediately.
# Before MRV: always pick first empty cell (0, 2)
# Candidates for (0, 2): {1, 2, 4} - 3 options

# With MRV: find cell with minimum candidates
# Cell (0, 5): {8} - only 1 option! Fill it immediately.
# Cell (1, 1): {2, 4, 7} - 3 options
# Cell (2, 0): {1, 2} - 2 options
# MRV picks (0, 5) first - fewer branches to explore

solver = SudokuSolver(puzzle)
cell = solver.find_best_cell()  # Returns cell with minimum candidates
# If any cell has 0 candidates, puzzle is unsolvable from this state
Part C: Find All Solutions / Constraint Propagation
Problem: Part C
Extend to find all solutions (some puzzles have multiple). Add constraint propagation: when you place a value, immediately eliminate it from candidates in the same row, column, and box. If any cell becomes empty (no candidates), backtrack early.
# Finding all solutions
solver = SudokuSolver(puzzle_with_multiple_solutions)
all_solutions = solver.solve(find_all=True)
print(f"Found {len(all_solutions)} solutions")

# Constraint propagation when placing 5 at (0, 0):
# Before: candidates[0][1] = {5, 7, 8}
# After:  candidates[0][1] = {7, 8}  # 5 eliminated (same row)

# Early termination: if candidates[r][c] becomes empty set,
# this branch is unsolvable - backtrack immediately
Interview comments
Edge cases to probe:
What if the input puzzle is already invalid (duplicate in row/col/box)?
What if the puzzle has no solution?
How do you handle a completely empty grid? (Many solutions)
How do you restore state when backtracking?
Common mistakes:
Box index calculation: 3 * (row // 3) + col // 3 - easy to get wrong
Not restoring grid state when backtracking
Checking validity after placement instead of before
Shallow copy vs deep copy when saving state for backtracking
Off-by-one: checking column with row loop variable
Code solutions
Solution 1 implements basic backtracking with validity checking, finding the first empty cell and trying values 1-9. Solution 2 adds constraint propagation and the MRV (Minimum Remaining Values) heuristic, maintaining candidate sets for each cell. Solution 3 uses bitmask-based constraint tracking for O(1) validity checks with bit manipulation. The key difference is how constraints are tracked and checked: iterative validation vs candidate sets vs bitmasks. Core techniques: backtracking with state restoration, box index formula (3*(r//3) + c//3), MRV heuristic, bitwise operations for constraint masks.
Solution 1: Basic backtracking
Basic backtracking with validity checking. find_empty() returns first empty cell. is_valid_placement() checks row, column, and 3x3 box constraints. Supports finding all solutions via flag.
"""Sudoku Solver - Solution 1: Basic Backtracking with Constraint Checking"""from dataclasses import dataclassfrom typing import Optional@dataclassclass SudokuBoard: grid: list[list[int]] # 0 represents empty cell def is_valid_placement(self, row: int, col: int, num: int) -> bool: """Check if placing num at (row, col) is valid.""" # Check row if num in self.grid[row]: return False # Check column if num in (self.grid[r][col] for r in range(9)): return False # Check 3x3 box box_row, box_col = 3 * (row // 3), 3 * (col // 3) for r in range(box_row, box_row + 3): for c in range(box_col, box_col + 3): if self.grid[r][c] == num: return False return True def find_empty(self) -> Optional[tuple[int, int]]: """Find the next empty cell (contains 0).""" for r in range(9): for c in range(9): if self.grid[r][c] == 0: return (r, c) return None def is_valid_board(self) -> bool: """Validate current board state has no conflicts.""" for r in range(9): for c in range(9): if self.grid[r][c] != 0: num = self.grid[r][c] self.grid[r][c] = 0 # Temporarily remove valid = self.is_valid_placement(r, c, num) self.grid[r][c] = num # Restore if not valid: return False return Truedef solve(board: SudokuBoard, find_all: bool = False) -> list[list[list[int]]]: """Solve sudoku using backtracking. Returns list of solutions.""" solutions: list[list[list[int]]] = [] def backtrack() -> bool: empty = board.find_empty() if empty is None: solutions.append([row[:] for row in board.grid]) return not find_all # Stop if only finding one row, col = empty for num in range(1, 10): if board.is_valid_placement(row, col, num): board.grid[row][col] = num if backtrack(): return True board.grid[row][col] = 0 return False if not board.is_valid_board(): return [] backtrack() return solutions
Solution 2: Constraint propagation with MRV
Backtracking with constraint propagation and MRV heuristic. Maintains a candidates grid (sets of possible values for each cell). When a value is placed, it’s eliminated from all peers. find_best_cell() returns the cell with minimum candidates.
"""Sudoku Solver - Solution 2: Backtracking with Constraint Propagation"""from dataclasses import dataclass, fieldfrom typing import Optional@dataclassclass SudokuSolver: grid: list[list[int]] candidates: list[list[set[int]]] = field(init=False) def __post_init__(self) -> None: self.candidates = [[set(range(1, 10)) for _ in range(9)] for _ in range(9)] self._initialize_constraints() def _initialize_constraints(self) -> None: """Remove candidates based on initial board state.""" for r in range(9): for c in range(9): if self.grid[r][c] != 0: self.candidates[r][c] = set() self._eliminate(r, c, self.grid[r][c]) def _eliminate(self, row: int, col: int, num: int) -> None: """Remove num from candidates in same row, column, and box.""" for c in range(9): self.candidates[row][c].discard(num) for r in range(9): self.candidates[r][col].discard(num) box_r, box_c = 3 * (row // 3), 3 * (col // 3) for r in range(box_r, box_r + 3): for c in range(box_c, box_c + 3): self.candidates[r][c].discard(num) def is_valid(self) -> bool: """Check if board is still solvable (no empty cells with no candidates).""" for r in range(9): for c in range(9): if self.grid[r][c] == 0 and len(self.candidates[r][c]) == 0: return False return True def find_best_cell(self) -> Optional[tuple[int, int]]: """Find empty cell with minimum remaining candidates (MRV heuristic).""" best, min_count = None, 10 for r in range(9): for c in range(9): if self.grid[r][c] == 0 and len(self.candidates[r][c]) < min_count: best, min_count = (r, c), len(self.candidates[r][c]) return best def solve(self, find_all: bool = False) -> list[list[list[int]]]: """Solve using backtracking with constraint propagation.""" solutions: list[list[list[int]]] = [] def backtrack() -> bool: cell = self.find_best_cell() if cell is None: solutions.append([row[:] for row in self.grid]) return not find_all row, col = cell saved_candidates = [row[:] for row in self.candidates] for num in list(self.candidates[row][col]): self.grid[row][col] = num self.candidates[row][col] = set() self._eliminate(row, col, num) if self.is_valid() and backtrack(): return True self.grid[row][col] = 0 self.candidates = [r[:] for r in saved_candidates] return False backtrack() return solutions
Solution 3: Bitmask-based constraint tracking
Bitmask-based constraint tracking for O(1) validity checks. Uses three arrays of 9-bit integers (rows, cols, boxes) where bit N is set if value N is used. get_candidates() computes valid values via bitwise OR and complement. Most efficient for large numbers of solve attempts.
"""Sudoku Solver - Solution 3: Bitmask-based Constraint Tracking"""from dataclasses import dataclass, fieldfrom typing import Optional@dataclassclass BitSudoku: grid: list[list[int]] rows: list[int] = field(default_factory=list) # Bitmask of used numbers per row cols: list[int] = field(default_factory=list) # Bitmask of used numbers per column boxes: list[int] = field(default_factory=list) # Bitmask of used numbers per 3x3 box _valid: bool = True @classmethod def from_grid(cls, grid: list[list[int]]) -> "BitSudoku": rows, cols, boxes = [0] * 9, [0] * 9, [0] * 9 valid = True for r in range(9): for c in range(9): if grid[r][c] != 0: bit = 1 << grid[r][c] box_idx = 3 * (r // 3) + c // 3 # Check for conflicts before adding if (rows[r] | cols[c] | boxes[box_idx]) & bit: valid = False rows[r] |= bit cols[c] |= bit boxes[box_idx] |= bit result = cls([row[:] for row in grid], rows, cols, boxes, valid) return result def get_candidates(self, r: int, c: int) -> list[int]: """Get valid candidates for cell using bitmasks.""" used = self.rows[r] | self.cols[c] | self.boxes[3 * (r // 3) + c // 3] return [n for n in range(1, 10) if not (used & (1 << n))] def find_mrv_cell(self) -> Optional[tuple[int, int, list[int]]]: """Find empty cell with minimum remaining values.""" best: Optional[tuple[int, int, list[int]]] = None for r in range(9): for c in range(9): if self.grid[r][c] == 0: cands = self.get_candidates(r, c) if not cands: return (r, c, []) # Early exit: unsolvable if best is None or len(cands) < len(best[2]): best = (r, c, cands) return best def place(self, r: int, c: int, num: int) -> None: bit = 1 << num self.grid[r][c] = num self.rows[r] |= bit self.cols[c] |= bit self.boxes[3 * (r // 3) + c // 3] |= bit def remove(self, r: int, c: int, num: int) -> None: bit = 1 << num self.grid[r][c] = 0 self.rows[r] ^= bit self.cols[c] ^= bit self.boxes[3 * (r // 3) + c // 3] ^= bit def solve(self, find_all: bool = False) -> list[list[list[int]]]: if not self._valid: return [] solutions: list[list[list[int]]] = [] def backtrack() -> bool: cell = self.find_mrv_cell() if cell is None: solutions.append([row[:] for row in self.grid]) return not find_all r, c, candidates = cell if not candidates: return False for num in candidates: self.place(r, c, num) if backtrack(): return True self.remove(r, c, num) return False backtrack() return solutions
Question 44 - Circuit Breaker
Difficulty: 9 / 10
Approximate lines of code: 80 LoC
Tags: distributed-systems
Description
The circuit breaker pattern prevents cascading failures when calling unreliable external services. Instead of repeatedly hammering a failing service, the circuit breaker “trips” after detecting failures and fails fast for a cooldown period. This protects both your system (avoiding blocked threads waiting for timeouts) and the downstream service (giving it time to recover).
The pattern uses a three-state machine: CLOSED (normal operation, tracking failures), OPEN (failing fast, rejecting all calls), and HALF-OPEN (testing if the service recovered). The key data structures are failure counters/timestamps and the current state. The state transitions lazily on each call attempt rather than via background timers.
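The lazy transition is the detail people most often miss. A minimal sketch of how the OPEN→HALF_OPEN check can live in a property read on every call rather than in a background timer (names are illustrative):

import time
from enum import Enum

class State(Enum):
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"

class LazyStateMixin:
    """State is re-evaluated whenever it is read - no background timer thread needed."""
    recovery_timeout = 30.0

    def __init__(self) -> None:
        self._state = State.CLOSED
        self._opened_at = 0.0

    @property
    def state(self) -> State:
        elapsed = time.monotonic() - self._opened_at
        if self._state is State.OPEN and elapsed >= self.recovery_timeout:
            self._state = State.HALF_OPEN  # the next call becomes the probe
        return self._state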
Part A: Basic State Machine
Problem: Part A
Implement a circuit breaker with three states and count-based tripping. After N consecutive failures in CLOSED state, trip to OPEN. In OPEN, reject calls immediately with a CircuitOpenError. After a timeout period elapses, transition to HALF-OPEN and allow one probe call through. If it succeeds, go to CLOSED. If it fails, go back to OPEN.
cb = CircuitBreaker(failure_threshold=3, recovery_timeout=30.0)

# Normal operation (CLOSED)
result = cb.call(lambda: external_service())  # Works normally

# After 3 consecutive failures, circuit trips to OPEN
for _ in range(3):
    try:
        cb.call(lambda: raise_error())
    except ServiceError:
        pass
# Internal state: _state=OPEN, _failure_count=3, _last_failure_time=<now>

# Now calls fail fast without invoking the function
cb.call(lambda: external_service())  # Raises CircuitOpenError immediately

# After recovery_timeout seconds, state becomes HALF_OPEN on next call attempt
# One successful call transitions back to CLOSED
Part B: Rate-Based Thresholds
Problem: Part B
Instead of counting consecutive failures, trip when the failure rate exceeds a threshold within a sliding time window. This is more robust for high-traffic services where a few scattered failures are acceptable. Track calls in a deque with (timestamp, success) tuples. Only evaluate the rate after min_calls have been made to avoid tripping on startup noise.
cb = CircuitBreaker(
    failure_rate_threshold=0.5,  # 50% failure rate
    window_size=60.0,            # 60 second sliding window
    min_calls=10,                # Need at least 10 calls before evaluating
    recovery_timeout=30.0,
)

# Internal state tracks recent calls:
# _calls = deque([
#     (1704067200.0, True),   # success
#     (1704067201.0, False),  # failure
#     (1704067202.0, True),   # success
#     ...
# ])

# _failure_rate() prunes old calls, then computes:
# failures = sum(1 for _, success in _calls if not success)
# rate = failures / len(_calls)

# In HALF_OPEN, require multiple successful calls before closing
# (single success could be a fluke)
Part C: Thread Safety and Metrics
Problem: Part C
Make the circuit breaker thread-safe and add state transition callbacks for observability. Use locks around state mutations and the call tracking data structures. Callbacks fire on CLOSED→OPEN, OPEN→HALF_OPEN, and HALF_OPEN→CLOSED transitions.
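A minimal sketch of the locking and callback wiring, assuming a single on_state_change(old, new) hook (the hook name is my assumption):

import threading
from typing import Callable, Optional

class ObservableBreaker:
    """Guards state mutations with a lock and fires a callback on each transition."""

    def __init__(self, on_state_change: Optional[Callable[[str, str], None]] = None):
        self._lock = threading.Lock()
        self._state = "closed"
        self._on_state_change = on_state_change

    def record_failure(self) -> None:
        with self._lock:
            old, new = self._state, "open"   # simplified: real code checks thresholds first
            self._state = new
        if self._on_state_change and old != new:
            self._on_state_change(old, new)  # fire outside the lock to avoid deadlocks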
Interview comments
Edge cases to probe:
What happens if the test call in HALF_OPEN fails? (Should go back to OPEN, not CLOSED)
Do you reset the failure count on success in the CLOSED state? (Yes; otherwise two failures followed by a success still leave the count at 2, and scattered failures eventually trip the circuit)
What time function do you use? (time.monotonic() not time.time() to avoid clock jumps)
How do you handle the exception after recording the failure? (Must re-raise it)
Common mistakes:
Forgetting to re-raise exceptions after recording failure
Not resetting failure count on success in CLOSED state
Using time.time() instead of time.monotonic() (system clock can jump backwards)
Putting the OPEN→HALF_OPEN transition check only in a dedicated method or timer, instead of evaluating it lazily on each call attempt
In rate-based: not having a minimum call threshold before evaluating rate
In HALF_OPEN: transitioning to CLOSED after just one success (should test multiple times)
Code solutions
Solution 1 uses a simple count-based approach with consecutive failure tracking. Solution 2 implements sliding window rate-based tripping using a deque of timestamped call results. Solution 3 employs the State pattern with explicit state classes for each circuit state. The key difference is how they detect failures: fixed count vs. failure rate percentage, and how they structure state transitions: implicit property-based vs. explicit state objects.
Core techniques: state machine, sliding window, decorator pattern, lazy evaluation.
Solution 1: Simple count-based implementation
Simple count-based implementation using a dataclass. State is checked lazily via a property that transitions OPEN→HALF_OPEN when timeout expires. Uses consecutive failure counting with reset on success.
"""Circuit Breaker Pattern - Solution 1: Simple Count-Based ImplementationUses a fixed failure count threshold to trip the circuit."""from dataclasses import dataclass, fieldfrom enum import Enumfrom typing import Callable, TypeVarimport timeT = TypeVar("T")class State(Enum): CLOSED = "closed" # Normal operation OPEN = "open" # Failing, reject all calls HALF_OPEN = "half_open" # Testing if service recovered@dataclassclass CircuitBreaker: failure_threshold: int = 5 recovery_timeout: float = 30.0 _state: State = field(default=State.CLOSED, init=False) _failure_count: int = field(default=0, init=False) _last_failure_time: float = field(default=0.0, init=False) @property def state(self) -> State: if self._state == State.OPEN: if time.time() - self._last_failure_time >= self.recovery_timeout: self._state = State.HALF_OPEN return self._state def call(self, func: Callable[[], T]) -> T: if self.state == State.OPEN: raise CircuitOpenError("Circuit is open") try: result = func() self._on_success() return result except Exception as e: self._on_failure() raise e def _on_success(self) -> None: self._failure_count = 0 self._state = State.CLOSED def _on_failure(self) -> None: self._failure_count += 1 self._last_failure_time = time.time() if self._failure_count >= self.failure_threshold: self._state = State.OPENclass CircuitOpenError(Exception): pass
Solution 2: Sliding window rate-based implementation
Sliding window rate-based implementation. Stores calls in a deque with timestamps and success flags. Prunes old calls outside the window before calculating failure rate. Requires minimum calls before evaluating and multiple successes in HALF_OPEN.
"""Circuit Breaker Pattern - Solution 2: Sliding Window Rate-BasedUses a time-based sliding window to calculate failure rate."""from dataclasses import dataclass, fieldfrom enum import Enumfrom typing import Callable, TypeVarfrom collections import dequeimport timeT = TypeVar("T")class State(Enum): CLOSED = "closed" OPEN = "open" HALF_OPEN = "half_open"@dataclassclass CircuitBreaker: failure_rate_threshold: float = 0.5 # 50% failure rate window_size: float = 60.0 # 60 second window min_calls: int = 5 # Minimum calls before tripping recovery_timeout: float = 30.0 half_open_max_calls: int = 3 # Calls to test in half-open _state: State = field(default=State.CLOSED, init=False) _calls: deque = field(default_factory=deque, init=False) # (timestamp, success) _opened_at: float = field(default=0.0, init=False) _half_open_successes: int = field(default=0, init=False) @property def state(self) -> State: if self._state == State.OPEN: if time.time() - self._opened_at >= self.recovery_timeout: self._state = State.HALF_OPEN self._half_open_successes = 0 return self._state def _prune_old_calls(self) -> None: cutoff = time.time() - self.window_size while self._calls and self._calls[0][0] < cutoff: self._calls.popleft() def _failure_rate(self) -> float: self._prune_old_calls() if len(self._calls) < self.min_calls: return 0.0 failures = sum(1 for _, success in self._calls if not success) return failures / len(self._calls) def call(self, func: Callable[[], T]) -> T: if self.state == State.OPEN: raise CircuitOpenError("Circuit is open") try: result = func() self._record_success() return result except Exception as e: self._record_failure() raise e def _record_success(self) -> None: now = time.time() if self._state == State.HALF_OPEN: self._half_open_successes += 1 if self._half_open_successes >= self.half_open_max_calls: self._state = State.CLOSED self._calls.clear() else: self._calls.append((now, True)) def _record_failure(self) -> None: now = time.time() if self._state == State.HALF_OPEN: self._state = State.OPEN self._opened_at = now else: self._calls.append((now, False)) if self._failure_rate() >= self.failure_rate_threshold: self._state = State.OPEN self._opened_at = nowclass CircuitOpenError(Exception): pass
Solution 3: State pattern implementation
State pattern implementation with explicit state classes. Each state (ClosedState, OpenState, HalfOpenState) handles calls differently. Also provides decorator interfaces for protecting functions.
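A condensed sketch of that State-pattern shape, with illustrative class and method names rather than the full solution:

import time

class CircuitOpenError(Exception):
    pass

class ClosedState:
    def __init__(self, threshold: int = 3):
        self.failures, self.threshold = 0, threshold

    def before_call(self, breaker) -> None:
        pass  # calls flow through normally

    def on_failure(self, breaker) -> None:
        self.failures += 1
        if self.failures >= self.threshold:
            breaker.state = OpenState()

    def on_success(self, breaker) -> None:
        self.failures = 0

class OpenState:
    def __init__(self, recovery_timeout: float = 30.0):
        self.opened_at, self.recovery_timeout = time.monotonic(), recovery_timeout

    def before_call(self, breaker) -> None:
        if time.monotonic() - self.opened_at < self.recovery_timeout:
            raise CircuitOpenError("Circuit is open")
        breaker.state = HalfOpenState()  # timeout elapsed - allow a probe call

    def on_failure(self, breaker) -> None:
        pass

    def on_success(self, breaker) -> None:
        pass

class HalfOpenState:
    def before_call(self, breaker) -> None:
        pass  # the probe call is allowed through

    def on_failure(self, breaker) -> None:
        breaker.state = OpenState()

    def on_success(self, breaker) -> None:
        breaker.state = ClosedState()

class CircuitBreaker:
    def __init__(self):
        self.state = ClosedState()

    def call(self, func):
        self.state.before_call(self)
        try:
            result = func()
        except Exception:
            self.state.on_failure(self)
            raise  # always re-raise after recording the failure
        self.state.on_success(self)
        return result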
Question 45 - Load Balancer
Description
A load balancer distributes incoming requests across multiple backend servers to improve throughput and availability. Different strategies optimize for different goals: round-robin ensures even distribution, least-connections routes to the server with the lightest current load, and weighted distribution gives more traffic to higher-capacity servers.
Internally, you track a list of servers with their health status and connection counts. The core operation is route_request() which selects a server based on the current strategy, increments its connection count, and returns it. When the request completes, release_connection() decrements the count.
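A minimal sketch of that bookkeeping (the Server fields mirror the solutions below; the clamp on release is the detail worth noting):

from dataclasses import dataclass
from typing import Callable

@dataclass
class Server:
    id: str
    weight: int = 1
    is_healthy: bool = True
    active_connections: int = 0

def route(servers: list[Server], pick: Callable[[list[Server]], Server]) -> Server | None:
    """pick() is any strategy over the healthy servers; routing increments the count."""
    healthy = [s for s in servers if s.is_healthy]
    if not healthy:
        return None
    server = pick(healthy)
    server.active_connections += 1
    return server

def release(servers: list[Server], server_id: str) -> None:
    for s in servers:
        if s.id == server_id:
            s.active_connections = max(0, s.active_connections - 1)  # clamp at zero
            return

# e.g. least-connections is just a choice of pick:
# route(servers, lambda healthy: min(healthy, key=lambda s: s.active_connections))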
Part A: Round Robin
Problem: Part A
Implement round-robin load balancing that cycles through healthy servers in order. Maintain an index that advances on each request. Skip unhealthy servers.
lb = LoadBalancer(strategy=RoundRobin())
lb.add_server(Server("A"))
lb.add_server(Server("B"))
lb.add_server(Server("C"))

# Internal state:
#   servers: [A, B, C], all healthy
#   rr_index: 0

lb.route_request()  # Returns Server A, index -> 1
lb.route_request()  # Returns Server B, index -> 2
lb.route_request()  # Returns Server C, index -> 0 (wraps)
lb.route_request()  # Returns Server A, index -> 1

lb.set_health("B", False)
lb.route_request()  # Skips B, returns A or C based on current index
Part B: Least Connections
Problem: Part B
Implement least-connections routing that always selects the healthy server with the fewest active connections. Track connection counts, incrementing on route and decrementing on release.
lb = LoadBalancer(strategy=LeastConnections())
lb.add_server(Server("X", active_connections=5))
lb.add_server(Server("Y", active_connections=2))
lb.add_server(Server("Z", active_connections=8))

# Internal state:
#   X: 5 connections, Y: 2 connections, Z: 8 connections

lb.route_request()  # Returns Y (fewest: 2), Y now has 3
lb.route_request()  # Returns Y again (still fewest: 3), Y now has 4
lb.route_request()  # Returns Y (4), Y now has 5
lb.route_request()  # Returns X or Y (tied at 5)
lb.release_connection("Y")  # Y back to 4
Part C: Health Checks and Weighted Distribution
Problem: Part C
Add health status tracking - unhealthy servers should never receive requests. Implement weighted distribution where servers with higher weights receive proportionally more traffic.
lb = LoadBalancer(strategy=Weighted())
lb.add_server(Server("heavy", weight=3))
lb.add_server(Server("light", weight=1))

# Over 8 requests, "heavy" should get ~6, "light" should get ~2
# (3:1 ratio based on weights)
results = [lb.route_request().id for _ in range(8)]
# Results like: ["heavy", "heavy", "heavy", "light", "heavy", "heavy", "heavy", "light"]

# Health check filtering
lb.set_health("heavy", False)
lb.route_request()  # Only "light" available, always returns it
Interview comments
Edge cases to probe:
What if all servers are unhealthy? (Return None or raise exception)
What happens to round-robin index when a server is removed? (May skip or wrap incorrectly)
What if weighted servers have weight 0? (Should never be selected)
What about negative connection counts from buggy release calls? (Clamp to 0)
Common mistakes:
Round-robin index overflow or incorrect modulo after server removal
Not filtering unhealthy servers before selection
Forgetting to increment connection count on route
Weighted distribution that doesn’t actually respect ratios (often reset logic is wrong)
Checking health at the wrong time (it should be checked at selection time, not only when the server is added)
Code solutions
Solution 1 uses the Strategy pattern with separate classes for RoundRobin, LeastConnections, and Weighted routing. Solution 2 takes a functional approach with enums and a dispatch dictionary, using itertools.cycle for round-robin. Solution 3 is a direct implementation with explicit methods per strategy and optional callback support. These vary in their design patterns: OOP polymorphism vs functional dispatch vs simple methods. Core techniques: strategy pattern, modulo cycling, min-heap selection, weighted round-robin.
Solution 1: Strategy Pattern
Object-oriented approach with Strategy pattern. Defines a RoutingStrategy protocol with separate classes for RoundRobin, LeastConnections, and Weighted. Clean separation of concerns.
"""Load Balancer - Solution 1: Object-Oriented with Strategy Pattern"""from dataclasses import dataclass, fieldfrom typing import Protocol, Optionalfrom abc import abstractmethod@dataclassclass Server: id: str weight: int = 1 is_healthy: bool = True active_connections: int = 0class RoutingStrategy(Protocol): @abstractmethod def select_server(self, servers: list[Server]) -> Optional[Server]: ...@dataclassclass RoundRobinStrategy: _index: int = 0 def select_server(self, servers: list[Server]) -> Optional[Server]: healthy = [s for s in servers if s.is_healthy] if not healthy: return None server = healthy[self._index % len(healthy)] self._index = (self._index + 1) % len(healthy) return server@dataclassclass LeastConnectionsStrategy: def select_server(self, servers: list[Server]) -> Optional[Server]: healthy = [s for s in servers if s.is_healthy] if not healthy: return None return min(healthy, key=lambda s: s.active_connections)@dataclassclass WeightedStrategy: _counter: dict[str, int] = field(default_factory=dict) def select_server(self, servers: list[Server]) -> Optional[Server]: healthy = [s for s in servers if s.is_healthy] if not healthy: return None for s in healthy: if s.id not in self._counter: self._counter[s.id] = 0 selected = max(healthy, key=lambda s: s.weight - self._counter[s.id]) self._counter[selected.id] += 1 total_weight = sum(s.weight for s in healthy) if all(self._counter.get(s.id, 0) >= s.weight for s in healthy): self._counter = {s.id: 0 for s in healthy} return selected@dataclassclass LoadBalancer: strategy: RoutingStrategy servers: list[Server] = field(default_factory=list) def add_server(self, server: Server) -> None: self.servers.append(server) def remove_server(self, server_id: str) -> bool: for i, s in enumerate(self.servers): if s.id == server_id: self.servers.pop(i) return True return False def set_health(self, server_id: str, is_healthy: bool) -> None: for s in self.servers: if s.id == server_id: s.is_healthy = is_healthy return def route_request(self) -> Optional[Server]: server = self.strategy.select_server(self.servers) if server: server.active_connections += 1 return server def release_connection(self, server_id: str) -> None: for s in self.servers: if s.id == server_id and s.active_connections > 0: s.active_connections -= 1 return
Solution 2: Functional with Enum Dispatch
Functional approach using an enum for strategy selection and a dispatch dictionary. Uses itertools.cycle for round-robin. All routing logic is methods on the LoadBalancer class.
"""Load Balancer - Solution 2: Functional approach with enums"""from dataclasses import dataclass, fieldfrom enum import Enum, autofrom typing import Optional, Callablefrom itertools import cycleclass Strategy(Enum): ROUND_ROBIN = auto() LEAST_CONNECTIONS = auto() WEIGHTED = auto()@dataclassclass Server: id: str weight: int = 1 is_healthy: bool = True active_connections: int = 0@dataclassclass LoadBalancer: strategy: Strategy _servers: dict[str, Server] = field(default_factory=dict) _rr_cycle: Optional[cycle] = field(default=None, repr=False) _weight_tracker: dict[str, int] = field(default_factory=dict) def add_server(self, server: Server) -> None: self._servers[server.id] = server self._reset_rr_cycle() def remove_server(self, server_id: str) -> bool: if server_id in self._servers: del self._servers[server_id] self._weight_tracker.pop(server_id, None) self._reset_rr_cycle() return True return False def set_health(self, server_id: str, is_healthy: bool) -> None: if server_id in self._servers: self._servers[server_id].is_healthy = is_healthy self._reset_rr_cycle() def _reset_rr_cycle(self) -> None: healthy = self._get_healthy_servers() self._rr_cycle = cycle(healthy) if healthy else None def _get_healthy_servers(self) -> list[Server]: return [s for s in self._servers.values() if s.is_healthy] def _select_round_robin(self, healthy: list[Server]) -> Optional[Server]: if not self._rr_cycle: return None for _ in range(len(self._servers)): server = next(self._rr_cycle) if server.is_healthy: return server return None def _select_least_connections(self, healthy: list[Server]) -> Optional[Server]: return min(healthy, key=lambda s: s.active_connections) if healthy else None def _select_weighted(self, healthy: list[Server]) -> Optional[Server]: if not healthy: return None for s in healthy: self._weight_tracker.setdefault(s.id, 0) selected = max(healthy, key=lambda s: s.weight - self._weight_tracker[s.id]) self._weight_tracker[selected.id] += 1 if all(self._weight_tracker[s.id] >= s.weight for s in healthy): self._weight_tracker = {s.id: 0 for s in healthy} return selected def route_request(self) -> Optional[Server]: healthy = self._get_healthy_servers() if not healthy: return None selectors: dict[Strategy, Callable] = { Strategy.ROUND_ROBIN: self._select_round_robin, Strategy.LEAST_CONNECTIONS: self._select_least_connections, Strategy.WEIGHTED: self._select_weighted, } server = selectors[self.strategy](healthy) if server: server.active_connections += 1 return server def release_connection(self, server_id: str) -> None: if server_id in self._servers: s = self._servers[server_id] s.active_connections = max(0, s.active_connections - 1)
Solution 3: Direct Methods with Callbacks
Simple direct implementation with explicit method for each strategy. Uses heapq.nsmallest for least-connections. Includes optional callback support for routing events.
"""Load Balancer - Solution 3: Simple, direct implementation with callbacks"""from dataclasses import dataclass, fieldfrom typing import Optional, Callableimport heapq@dataclassclass Server: id: str weight: int = 1 is_healthy: bool = True active_connections: int = 0 def __lt__(self, other: "Server") -> bool: return self.active_connections < other.active_connections@dataclassclass LoadBalancer: _servers: list[Server] = field(default_factory=list) _rr_index: int = 0 _weight_remaining: dict[str, int] = field(default_factory=dict) _on_route: Optional[Callable[[Server], None]] = None def add_server(self, server: Server) -> None: self._servers.append(server) self._weight_remaining[server.id] = server.weight def remove_server(self, server_id: str) -> bool: for i, s in enumerate(self._servers): if s.id == server_id: self._servers.pop(i) self._weight_remaining.pop(server_id, None) return True return False def set_health(self, server_id: str, healthy: bool) -> None: for s in self._servers: if s.id == server_id: s.is_healthy = healthy break def _healthy(self) -> list[Server]: return [s for s in self._servers if s.is_healthy] def round_robin(self) -> Optional[Server]: healthy = self._healthy() if not healthy: return None self._rr_index %= len(healthy) server = healthy[self._rr_index] self._rr_index = (self._rr_index + 1) % len(healthy) return server def least_connections(self) -> Optional[Server]: healthy = self._healthy() if not healthy: return None return heapq.nsmallest(1, healthy)[0] def weighted(self) -> Optional[Server]: healthy = self._healthy() if not healthy: return None for s in healthy: if s.id not in self._weight_remaining or self._weight_remaining[s.id] <= 0: self._weight_remaining[s.id] = s.weight if all(self._weight_remaining.get(s.id, 0) <= 0 for s in healthy): for s in healthy: self._weight_remaining[s.id] = s.weight best = max(healthy, key=lambda s: self._weight_remaining.get(s.id, 0)) self._weight_remaining[best.id] -= 1 return best def route(self, strategy: str) -> Optional[Server]: strategies = { "round_robin": self.round_robin, "least_connections": self.least_connections, "weighted": self.weighted, } server = strategies[strategy]() if server: server.active_connections += 1 if self._on_route: self._on_route(server) return server def release(self, server_id: str) -> None: for s in self._servers: if s.id == server_id: s.active_connections = max(0, s.active_connections - 1) break