Following the multivariate chain rule, we end up with a long "chain" of gradients that we have to determine.
Working this out by hand can be tiresome, which is why people follow a Derive-as-you-Go pattern.
Setup
We saw previously that the chain rule in multivariate calculus involves some annoying summations and component-wise analysis, which makes it hard to derive gradients in one shot.
Suppose we have $L = f_3(f_2(f_1(\theta)))$, where:
- $\theta$ is your parameter
- $f_1$ is the first intermediate quantity
- $f_2$ is the second intermediate quantity
- $f_3$ is the third intermediate quantity
- $L$ is the final scalar
We apply the chain rule repeatedly, but as separate summations.

Step 1: From $\theta$ to $f_2$:
$$\frac{\partial f_{2,j}}{\partial \theta} = \sum_k \frac{\partial f_{2,j}}{\partial f_{1,k}} \frac{\partial f_{1,k}}{\partial \theta}$$

Step 2: From $f_2$ to $f_3$:
$$\frac{\partial f_{3,i}}{\partial \theta} = \sum_j \frac{\partial f_{3,i}}{\partial f_{2,j}} \frac{\partial f_{2,j}}{\partial \theta}$$

Step 3: From $f_3$ to $L$:
$$\frac{\partial L}{\partial \theta} = \sum_i \frac{\partial L}{\partial f_{3,i}} \frac{\partial f_{3,i}}{\partial \theta}$$

If you substitute Step 1 into Step 2 into Step 3, you get one big nested sum:
$$\frac{\partial L}{\partial \theta} = \sum_i \sum_j \sum_k \frac{\partial L}{\partial f_{3,i}} \frac{\partial f_{3,i}}{\partial f_{2,j}} \frac{\partial f_{2,j}}{\partial f_{1,k}} \frac{\partial f_{1,k}}{\partial \theta}$$
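To see that the nested sum is just a product of Jacobians, here is a small numerical sketch; the functions `f1`, `f2`, `f3` below are invented for illustration. It multiplies the per-step Jacobians together and checks the result against a finite-difference estimate of $\frac{\partial L}{\partial \theta}$.

```python
import numpy as np

theta = np.array([0.3, -1.2])                 # parameter vector, shape (2,)

def f1(t): return np.tanh(t)                                     # first intermediate, shape (2,)
def f2(a): return np.array([a[0] * a[1], a[0] + a[1], a[1]**2])  # second intermediate, shape (3,)
def f3(b): return b * b                                          # third intermediate, shape (3,)
def loss(c): return c.sum()                                      # final scalar L

def jacobian(f, x, eps=1e-6):
    """Numerical Jacobian J[i, j] = d f(x)_i / d x_j (forward differences)."""
    y0 = f(x)
    J = np.zeros((y0.size, x.size))
    for j in range(x.size):
        xp = x.copy()
        xp[j] += eps
        J[:, j] = (f(xp) - y0) / eps
    return J

a = f1(theta)
b = f2(a)
c = f3(b)

# Chain rule as separate summations: each matrix product sums over one index.
# dL/dtheta = dL/df3 @ df3/df2 @ df2/df1 @ df1/dtheta
dL_dc = np.ones(c.size)                       # dL/df3 for a sum loss
grad = dL_dc @ jacobian(f3, b) @ jacobian(f2, a) @ jacobian(f1, theta)

# Finite-difference check of the whole composition L(theta)
full = lambda t: loss(f3(f2(f1(t))))
eps = 1e-6
fd = np.array([(full(theta + e) - full(theta)) / eps for e in np.eye(2) * eps])
print(grad)
print(fd)    # should agree closely with grad
```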
How to Computational Graph
The Key Insight
Instead of deriving the full nested sum all at once, we can build a graph structure where:
- Nodes represent values ($\theta$, $f_1$, $f_2$, $f_3$, $L$)
- Edges represent dependencies (how outputs depend on inputs)
Then we compute gradients by traversing the graph backwards.
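A minimal sketch of this idea in Python (the `Node` class and `backward` function are hypothetical, not any particular library's API):

```python
# Each node stores its value, the nodes it depends on (edges), and the local
# gradient d(node)/d(parent) for each edge. Backward walks the graph in reverse.
class Node:
    def __init__(self, value, parents=(), local_grads=()):
        self.value = value
        self.parents = parents            # edges: which nodes this one depends on
        self.local_grads = local_grads    # d(self)/d(parent), one per parent
        self.grad = 0.0                   # accumulated dL/d(self)

def backward(output):
    # Visit nodes in reverse topological order, so each node's upstream gradient
    # is complete before it gets pushed further back.
    order, seen = [], set()
    def topo(n):
        if id(n) not in seen:
            seen.add(id(n))
            for p in n.parents:
                topo(p)
            order.append(n)
    topo(output)
    output.grad = 1.0                     # seed: dL/dL = 1
    for node in reversed(order):
        for parent, local in zip(node.parents, node.local_grads):
            parent.grad += node.grad * local   # chain rule; += sums over paths

# Tiny usage: L = (2*theta) + (2*theta) reuses f1, so dL/dtheta = 2 + 2 = 4
theta = Node(3.0)
f1 = Node(2 * theta.value, parents=(theta,), local_grads=(2.0,))
L = Node(f1.value + f1.value, parents=(f1, f1), local_grads=(1.0, 1.0))
backward(L)
print(theta.grad)   # 4.0
```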
Doing it by hand
Given a network architecture (a worked numerical sketch follows this list):
- Sketch the forward pass
- Begin at the loss
- Compute the layer gradients
- Compute the local gradients
- Apply the chain rule
- Pass the gradient with respect to the layer's input back to the previous layer
- Repeat until you reach the parameters
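Here is a worked sketch of that recipe for a tiny made-up network (linear → ReLU → linear → squared error); the names `W1`, `W2`, `x`, `y` are illustrative, not from the original text.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=3)                 # input
y = np.array([1.0])                    # target
W1 = rng.normal(size=(4, 3))           # first layer weights
W2 = rng.normal(size=(1, 4))           # second layer weights

# 1. Sketch the forward pass, storing every intermediate
z1 = W1 @ x                            # f1
a1 = np.maximum(z1, 0.0)               # f2 (ReLU)
z2 = W2 @ a1                           # f3
L = 0.5 * np.sum((z2 - y) ** 2)        # loss

# 2. Begin at the loss; 3-4. compute local gradients and apply the chain rule
dL_dz2 = z2 - y                        # dL/df3
dL_dW2 = np.outer(dL_dz2, a1)          # layer gradient for W2
dL_da1 = W2.T @ dL_dz2                 # 5. gradient passed back to the previous layer
dL_dz1 = dL_da1 * (z1 > 0)             # local gradient of ReLU
dL_dW1 = np.outer(dL_dz1, x)           # layer gradient for W1

# Finite-difference spot check on one entry of W1
eps = 1e-6
W1p = W1.copy()
W1p[0, 0] += eps
Lp = 0.5 * np.sum((W2 @ np.maximum(W1p @ x, 0.0) - y) ** 2)
print(dL_dW1[0, 0], (Lp - L) / eps)    # the two numbers should agree closely
```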
Forward Pass: Build the Graph
As we compute the forward pass, we:
- Store each intermediate value
- Record the operations used to create them
- Build edges showing what depends on what
Example:
θ → f₁ → f₂ → f₃ → L
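One way to picture this is as a "tape": as each operation runs, we record the value it produced and its local gradient. The particular chain below (square → sin → scale → add) is just an invented example to show the bookkeeping.

```python
import math

# A minimal sketch of building the graph during the forward pass. The operations
# here (square, sin, scale, add) are illustrative, not from the original text.
def forward_with_tape(theta):
    tape = []                                # the graph, recorded in execution order
    f1 = theta ** 2
    tape.append(("square", f1, 2 * theta))   # (operation, value, local gradient df1/dtheta)
    f2 = math.sin(f1)
    tape.append(("sin", f2, math.cos(f1)))   # df2/df1
    f3 = 3 * f2
    tape.append(("scale", f3, 3.0))          # df3/df2
    L = f3 + 1
    tape.append(("add_one", L, 1.0))         # dL/df3
    return L, tape

L, tape = forward_with_tape(0.5)
for op, value, local_grad in tape:
    print(op, value, local_grad)
```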
Backward Pass: Traverse the Graph
Starting from $L$, we propagate gradients backwards through each edge:
Step 1: Initialize: $\frac{\partial L}{\partial L} = 1$
Step 2: From $L$ to $f_3$: compute $\frac{\partial L}{\partial f_3}$
Step 3: From $f_3$ to $f_2$ (using the chain rule): $\frac{\partial L}{\partial f_2} = \frac{\partial L}{\partial f_3} \frac{\partial f_3}{\partial f_2}$
Step 4: From $f_2$ to $f_1$: $\frac{\partial L}{\partial f_1} = \frac{\partial L}{\partial f_2} \frac{\partial f_2}{\partial f_1}$
Step 5: From $f_1$ to $\theta$: $\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial f_1} \frac{\partial f_1}{\partial \theta}$
At each step, we:
- Receive the gradient from the next layer: $\frac{\partial L}{\partial f_{i+1}}$
- Compute the local gradient: $\frac{\partial f_{i+1}}{\partial f_i}$
- Apply the chain rule: $\frac{\partial L}{\partial f_i} = \frac{\partial L}{\partial f_{i+1}} \frac{\partial f_{i+1}}{\partial f_i}$
- Pass $\frac{\partial L}{\partial f_i}$ to the previous layer
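As a concrete sketch of that per-edge recipe, here is a toy chain written out by hand. The chain $\theta \to f_1 = \theta^2 \to f_2 = \sin(f_1) \to f_3 = 3 f_2 \to L = f_3 + 1$ is an invented example; the loop is the recipe above.

```python
import math

theta = 0.5
f1 = theta ** 2            # forward pass, storing intermediates
f2 = math.sin(f1)
f3 = 3 * f2
L = f3 + 1

# Local gradients along each edge, ordered from L back to theta:
# dL/df3, df3/df2, df2/df1, df1/dtheta
local_grads = [1.0, 3.0, math.cos(f1), 2 * theta]

upstream = 1.0                     # Step 1: initialize with dL/dL = 1
for local in local_grads:
    upstream = upstream * local    # receive upstream gradient, apply chain rule, pass it back

# upstream is now dL/dtheta; compare with the closed form 6*theta*cos(theta^2)
print(upstream, 6 * theta * math.cos(theta ** 2))
```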
Branching: When Multiple Paths Exist
If a node has multiple children (is used in multiple places), we sum the gradients from all paths:
        ┌→ f₂ →┐
θ → f₁ ─┤      ├→ L
        └→ f₃ →┘
This automatically handles the summation in the chain rule!
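PyTorch applies the same rule: when a value feeds several branches, the gradients from each branch are accumulated. The two-path function below is an invented example.

```python
import math
import torch

theta = torch.tensor(2.0, requires_grad=True)
f1 = theta * 3             # f1 = 3*theta, used by two children
f2 = f1 ** 2               # path 1: f2 = 9*theta^2
f3 = torch.sin(f1)         # path 2: f3 = sin(3*theta)
L = f2 + f3
L.backward()

# dL/dtheta = 18*theta + 3*cos(3*theta): contributions from both paths are summed
manual = 18 * 2.0 + 3 * math.cos(6.0)
print(theta.grad.item(), manual)
```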
Node Types in the Graph
1. Leaf Nodes (Parameters)
- Nodes like $\theta$ (weights, biases)
- Require gradients - these are what we want to update
- `.requires_grad = True` in PyTorch
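For example, a minimal PyTorch snippet (the tensor names are illustrative):

```python
import torch

theta = torch.randn(3, requires_grad=True)   # leaf node: a parameter we want to update
L = (theta ** 2).sum()                       # forward pass builds the graph
L.backward()                                 # backward pass fills in theta.grad

print(theta.is_leaf)    # True
print(theta.grad)       # dL/dtheta = 2*theta
```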
2. Intermediate Nodes (Activations)
- Nodes like $f_1$, $f_2$, $f_3$
- Store values during forward pass
- Compute and pass gradients during backward pass
- Can be freed after backward pass to save memory
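In PyTorch this is why non-leaf tensors do not keep a `.grad` by default: their gradients are consumed during the backward pass and then dropped to save memory, unless you opt in with `retain_grad()`. A small illustrative snippet:

```python
import torch

theta = torch.randn(3, requires_grad=True)   # leaf node
f1 = theta * 2                               # intermediate node
f1.retain_grad()                             # opt in to keeping dL/df1 after backward
L = (f1 ** 2).sum()
L.backward()

print(f1.is_leaf)    # False: f1 is an intermediate node
print(f1.grad)       # dL/df1 = 2*f1, kept only because of retain_grad()
print(theta.grad)    # dL/dtheta = (dL/df1) * 2 = 8*theta
```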
3. Output Node (Loss)
- The node $L$
- Starting point for backward pass
- Always has gradient = 1
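Concretely, calling `backward()` on a scalar loss in PyTorch seeds the traversal with that gradient of 1; passing it explicitly gives the same result. A small illustrative snippet:

```python
import torch

theta = torch.tensor([1.0, 2.0], requires_grad=True)
L = (theta ** 3).sum()
L.backward(gradient=torch.tensor(1.0))   # explicit seed dL/dL = 1; same as L.backward()
print(theta.grad)                        # 3*theta**2 -> tensor([ 3., 12.])
```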
Operations Store Local Gradients
Each operation in the graph knows how to compute its local gradient:
| Operation | Forward | Backward (local gradient) |
|---|---|---|
| Add | $z = x + y$ | $\frac{\partial z}{\partial x} = 1$, $\frac{\partial z}{\partial y} = 1$ |
| Multiply | $z = x y$ | $\frac{\partial z}{\partial x} = y$, $\frac{\partial z}{\partial y} = x$ |
| Matrix multiply | $z = W x$ | $\frac{\partial L}{\partial W} = \frac{\partial L}{\partial z} x^\top$, $\frac{\partial L}{\partial x} = W^\top \frac{\partial L}{\partial z}$ |
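A small sketch of the same idea in code, assuming the rows above are elementary operations like add, multiply, and matrix multiply (the `OPS` dictionary is illustrative, not a library API):

```python
import numpy as np

OPS = {
    # op name: (forward, backward), where backward takes the inputs and the
    # upstream gradient `up` and returns the gradient for each input.
    "add":    (lambda x, y: x + y,
               lambda x, y, up: (up, up)),                      # dz/dx = 1, dz/dy = 1
    "mul":    (lambda x, y: x * y,
               lambda x, y, up: (up * y, up * x)),              # dz/dx = y, dz/dy = x
    "matmul": (lambda W, x: W @ x,
               lambda W, x, up: (np.outer(up, x), W.T @ up)),   # dL/dW, dL/dx
}

forward, backward = OPS["mul"]
z = forward(3.0, 4.0)
print(z, backward(3.0, 4.0, 1.0))   # 12.0, local grads (4.0, 3.0)
```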
