Following the multivariate chain rule, we end up with a large “chain” of gradients that we have to determine.

This can be tiresome if you do it by hand, which is why people follow a Derive as you Go pattern.

Setup

We saw previously that the chain rule in multivariate calculus involves some annoying summations and component-wise bookkeeping that make it hard to derive gradients in one shot.

Suppose we have:

$$L = L(f_3), \quad f_3 = f_3(f_2), \quad f_2 = f_2(f_1), \quad f_1 = f_1(\theta)$$

Where:

  • $\theta$ is your parameter
  • $f_1$ is the first intermediate quantity
  • $f_2$ is the second intermediate quantity
  • $f_3$ is the third intermediate quantity
  • $L$ is the final scalar

We apply the chain rule repeatedly, but as separate summations:

Step 1: From $L$ to $f_3$

$$\frac{\partial L}{\partial f_{3,k}} \quad \text{for each component } k$$

Step 2: From $f_3$ to $f_2$

$$\frac{\partial L}{\partial f_{2,j}} = \sum_k \frac{\partial L}{\partial f_{3,k}} \frac{\partial f_{3,k}}{\partial f_{2,j}}$$

Step 3: From $f_2$ to $f_1$

$$\frac{\partial L}{\partial f_{1,i}} = \sum_j \frac{\partial L}{\partial f_{2,j}} \frac{\partial f_{2,j}}{\partial f_{1,i}}$$

If you substitute Step 1 into Step 2 into Step 3 (and apply the chain rule once more, from $f_1$ to $\theta$), you get the full nested sum:

$$\frac{\partial L}{\partial \theta} = \sum_i \sum_j \sum_k \frac{\partial L}{\partial f_{3,k}} \frac{\partial f_{3,k}}{\partial f_{2,j}} \frac{\partial f_{2,j}}{\partial f_{1,i}} \frac{\partial f_{1,i}}{\partial \theta}$$
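To see that the nested sum is really just a product of per-step derivatives, here is a minimal numeric check in Python. The specific choices of $f_1, f_2, f_3$ and the loss are arbitrary illustrations, not anything from a particular network:

```python
# A minimal numeric check of the chained derivative above, using
# illustrative scalar functions (all choices here are hypothetical).
import math

def f1(theta): return theta ** 2        # f1 = θ²
def f2(a):     return math.sin(a)       # f2 = sin(f1)
def f3(b):     return 3.0 * b + 1.0     # f3 = 3·f2 + 1
def loss(c):   return c ** 2            # L  = f3²

theta = 0.7

# Forward pass: compute each intermediate quantity.
a = f1(theta); b = f2(a); c = f3(b); L = loss(c)

# Analytic chain rule: dL/dθ = ∂L/∂f3 · ∂f3/∂f2 · ∂f2/∂f1 · ∂f1/∂θ
dL_dc = 2.0 * c          # ∂L/∂f3
dc_db = 3.0              # ∂f3/∂f2
db_da = math.cos(a)      # ∂f2/∂f1
da_dtheta = 2.0 * theta  # ∂f1/∂θ
grad_analytic = dL_dc * dc_db * db_da * da_dtheta

# Finite-difference check of the same derivative.
eps = 1e-6
grad_numeric = (loss(f3(f2(f1(theta + eps)))) - L) / eps

print(grad_analytic, grad_numeric)  # should agree to roughly 1e-5
```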

How to Computational Graph

The Key Insight

Instead of deriving the full nested sum all at once, we can build a graph structure where:

  • Nodes represent intermediate values ($f_1$, $f_2$, $f_3$)
  • Edges represent dependencies (how outputs depend on inputs)

Then we compute gradients by traversing the graph backwards.
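A minimal sketch of what such a node could look like, as a toy structure for illustration rather than a real autograd API:

```python
# Each node stores its value, which nodes it was computed from (the
# incoming edges), and the local gradient attached to each of those edges.
from dataclasses import dataclass, field

@dataclass
class Node:
    value: float
    # (parent node, local gradient d(self)/d(parent)) for each incoming edge
    parents: list = field(default_factory=list)
    grad: float = 0.0  # filled in during the backward traversal

# Example: θ → f1 with f1 = θ², so the edge carries ∂f1/∂θ = 2θ.
theta = Node(2.0)
f1 = Node(theta.value ** 2, parents=[(theta, 2 * theta.value)])
```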

Doing it by hand

Given a network architecture:

  • Sketch the forward pass
  • Begin at Loss
  • Compute Layer Gradients
    • Compute local gradients
    • Apply chain rule
    • Pass the gradient with respect to the input up to the next layer
  • Repeat until you get to the top (a worked example follows below)
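Here is the recipe applied by hand to a deliberately tiny example: a single linear layer with a squared-error loss. All values and names are just illustrative.

```python
# By-hand backprop for z = w*x + b, L = (z - y)², following the recipe above.
x, y = 1.5, 2.0      # input and target
w, b = 0.8, 0.1      # parameters we want gradients for

# 1. Sketch the forward pass.
z = w * x + b
L = (z - y) ** 2

# 2. Begin at the loss: ∂L/∂L = 1.
dL_dL = 1.0

# 3. Compute layer gradients.
#    Local gradient of the loss node: ∂L/∂z = 2(z - y).
dL_dz = dL_dL * 2.0 * (z - y)

#    Local gradients of the linear node, then apply the chain rule.
dL_dw = dL_dz * x    # ∂z/∂w = x
dL_db = dL_dz * 1.0  # ∂z/∂b = 1
dL_dx = dL_dz * w    # ∂z/∂x = w  (the gradient passed up to the next layer)

print(dL_dw, dL_db, dL_dx)
```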

Forward Pass: Build the Graph

As we compute the forward pass, we:

  1. Store each intermediate value
  2. Record the operations used to create them
  3. Build edges showing what depends on what

Example:

θ → f₁ → f₂ → f₃ → L
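A sketch of what “recording the graph” could look like for this chain, using arbitrary example functions; the record helper is a made-up name for illustration, not a library call:

```python
# Forward pass that stores each intermediate value plus the local
# gradient of the edge that produced it (a simple "tape").
import math

tape = []  # records: name, value, parent, local gradient of that edge

def record(name, value, parent, local_grad):
    tape.append({"name": name, "value": value,
                 "parent": parent, "local_grad": local_grad})
    return value

theta = 0.7
f1 = record("f1", theta ** 2,     "theta", 2 * theta)      # ∂f1/∂θ
f2 = record("f2", math.sin(f1),   "f1",    math.cos(f1))   # ∂f2/∂f1
f3 = record("f3", 3.0 * f2 + 1.0, "f2",    3.0)            # ∂f3/∂f2
L  = record("L",  f3 ** 2,        "f3",    2 * f3)         # ∂L/∂f3

for step in tape:
    print(step)
```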

Backward Pass: Traverse the Graph

Starting from $L$, we propagate gradients backwards through each edge:

Step 1: Initialize

$$\frac{\partial L}{\partial L} = 1$$

Step 2: From $L$ to $f_3$

$$\frac{\partial L}{\partial f_3} \quad \text{(the local gradient of the loss node)}$$

Step 3: From $f_3$ to $f_2$ (using chain rule)

$$\frac{\partial L}{\partial f_2} = \frac{\partial L}{\partial f_3} \frac{\partial f_3}{\partial f_2}$$

Step 4: From $f_2$ to $f_1$

$$\frac{\partial L}{\partial f_1} = \frac{\partial L}{\partial f_2} \frac{\partial f_2}{\partial f_1}$$

Step 5: From $f_1$ to $\theta$

$$\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial f_1} \frac{\partial f_1}{\partial \theta}$$

At each step, we:

  1. Receive the gradient from the next layer: $\frac{\partial L}{\partial \text{output}}$
  2. Compute the local gradient: $\frac{\partial \text{output}}{\partial \text{input}}$
  3. Apply the chain rule: $\frac{\partial L}{\partial \text{input}} = \frac{\partial L}{\partial \text{output}} \cdot \frac{\partial \text{output}}{\partial \text{input}}$
  4. Pass the gradient to the previous layer (see the sketch below)
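A sketch of this backward traversal over the θ → f₁ → f₂ → f₃ → L chain, multiplying a running gradient by each edge's local gradient. The concrete functions are the same illustrative choices as in the forward-pass sketch:

```python
# Start with ∂L/∂L = 1 and walk the chain in reverse.
import math

theta = 0.7
f1 = theta ** 2
f2 = math.sin(f1)
f3 = 3.0 * f2 + 1.0
L  = f3 ** 2

# Local gradient of each edge, listed from the loss back to the parameter.
edges = [
    ("L  -> f3",    2 * f3),        # ∂L/∂f3
    ("f3 -> f2",    3.0),           # ∂f3/∂f2
    ("f2 -> f1",    math.cos(f1)),  # ∂f2/∂f1
    ("f1 -> theta", 2 * theta),     # ∂f1/∂θ
]

grad = 1.0  # Step 1: initialize ∂L/∂L = 1
for name, local_grad in edges:
    grad *= local_grad  # receive upstream gradient, apply chain rule
    print(f"after {name}: running gradient = {grad:.6f}")

# `grad` is now ∂L/∂θ for this chain.
```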

Branching: When Multiple Paths Exist

If a node has multiple children (is used in multiple places), we sum the gradients from all paths:

     ┌→ f₂ →┐
θ → f₁       ├→ L
     └→ f₃ →┘

For the graph above, $\frac{\partial L}{\partial f_1} = \frac{\partial L}{\partial f_2}\,\frac{\partial f_2}{\partial f_1} + \frac{\partial L}{\partial f_3}\,\frac{\partial f_3}{\partial f_1}$. This automatically handles the summation in the chain rule!
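In code, the summation shows up as simple gradient accumulation. This sketch uses arbitrary example functions for the two branches and checks the summed gradient against a finite difference:

```python
# Gradient accumulation when f1 feeds two branches that both reach L.
import math

theta = 0.7
f1 = theta ** 2
f2 = math.sin(f1)   # branch A
f3 = math.exp(f1)   # branch B
L  = f2 * f3

# Upstream gradients arriving at each branch.
dL_df2 = f3         # ∂L/∂f2
dL_df3 = f2         # ∂L/∂f3

# Gradients flowing into f1 are summed across all paths.
dL_df1 = 0.0
dL_df1 += dL_df2 * math.cos(f1)   # path through f2: ∂f2/∂f1 = cos(f1)
dL_df1 += dL_df3 * math.exp(f1)   # path through f3: ∂f3/∂f1 = exp(f1)

dL_dtheta = dL_df1 * 2 * theta    # finally, ∂f1/∂θ = 2θ

# Finite-difference check that the summed gradient is correct.
def loss(t):
    a = t ** 2
    return math.sin(a) * math.exp(a)

eps = 1e-6
print(dL_dtheta, (loss(theta + eps) - loss(theta)) / eps)
```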

Node Types in the Graph

1. Leaf Nodes (Parameters)

  • Nodes like $\theta$ (weights, biases)
  • Require gradients - these are what we want to update
  • .requires_grad = True in PyTorch

2. Intermediate Nodes (Activations)

  • Nodes like $f_1$, $f_2$, $f_3$
  • Store values during forward pass
  • Compute and pass gradients during backward pass
  • Can be freed after backward pass to save memory

3. Output Node (Loss)

  • The node $L$
  • Starting point for backward pass
  • Always has gradient = 1
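A small PyTorch snippet showing the three node types in action; the particular tensor values and functions are arbitrary examples:

```python
import torch

theta = torch.tensor(0.7, requires_grad=True)  # leaf node (parameter)

f1 = theta ** 2            # intermediate nodes (activations)
f2 = torch.sin(f1)
f2.retain_grad()           # non-leaf tensors don't keep .grad unless asked
L = (3.0 * f2 + 1.0) ** 2  # output node (the loss)

L.backward()               # backward pass starts here with ∂L/∂L = 1

print(theta.grad)          # the gradient we want, stored on the leaf
print(f2.grad)             # intermediate gradient (kept via retain_grad)
print(theta.is_leaf, f1.is_leaf)  # True, False
```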

Operations Store Local Gradients

Each operation in the graph knows how to compute its local gradient:

| Operation | Forward | Backward (local gradient) |
| --- | --- | --- |
| Addition | $z = x + y$ | $\frac{\partial z}{\partial x} = 1$, $\frac{\partial z}{\partial y} = 1$ |
| Multiplication | $z = x \cdot y$ | $\frac{\partial z}{\partial x} = y$, $\frac{\partial z}{\partial y} = x$ |
| Matrix multiply | $Z = XW$ | $\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Z} W^\top$, $\frac{\partial L}{\partial W} = X^\top \frac{\partial L}{\partial Z}$ |
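A toy sketch of operations that cache what they need during the forward pass and return their local gradients during the backward pass; this is an illustrative interface, not a real autograd library:

```python
class Multiply:
    def forward(self, x, y):
        self.x, self.y = x, y   # cache inputs for the backward pass
        return x * y

    def backward(self, upstream):
        # Local gradients: ∂(xy)/∂x = y, ∂(xy)/∂y = x
        return upstream * self.y, upstream * self.x

class Add:
    def forward(self, x, y):
        return x + y

    def backward(self, upstream):
        # Local gradients: ∂(x+y)/∂x = 1, ∂(x+y)/∂y = 1
        return upstream, upstream

# Usage: L = (a * b) + c, then backpropagate ∂L/∂L = 1.
mul, add = Multiply(), Add()
ab = mul.forward(2.0, 3.0)
L = add.forward(ab, 4.0)
d_ab, d_c = add.backward(1.0)
d_a, d_b = mul.backward(d_ab)
print(d_a, d_b, d_c)   # 3.0, 2.0, 1.0
```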