Multivariate Chain Rule

Given a composition of functions:

$$L = f(u), \qquad u = g(\theta)$$

Where:

  • $\theta$ is your parameter (can be scalar, vector, matrix, tensor - any shape)
  • $g$ is some intermediate function with output $u = g(\theta)$ (can output scalar, vector, matrix, tensor - any shape)
  • $L$ is the final output (can be scalar, vector, matrix, tensor - any shape)

The universal multivariable chain rule is given as:

$$\frac{\partial L}{\partial \theta_i} = \sum_j \frac{\partial L}{\partial u_j}\,\frac{\partial u_j}{\partial \theta_i}$$

To translate: In order to determine how $L$ changes with an element $\theta_i$, we need to sum up the contributions of all the elements $u_j$ of $u$ that depend on $\theta_i$.

Why “Sum over all indices of $u$”?

Because $u$ is the intermediate quantity that connects $\theta$ to $L$:

  • Each element $u_j$ depends on $\theta_i$
  • $L$ depends on each element $u_j$
  • To find how $L$ depends on $\theta_i$, you add up contributions from all the $u_j$ (verified numerically in the sketch below)
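
As a quick sanity check, here is a minimal numeric sketch. The functions `g` and `f` below are arbitrary hypothetical choices, not anything fixed by the rule: the point is that explicitly summing the per-component contributions of $u$ reproduces what PyTorch's autograd computes.

```python
import torch

torch.manual_seed(0)
theta = torch.randn(3, requires_grad=True)

# Hypothetical functions chosen only for the demo.
def g(t):                      # u = g(theta): R^3 -> R^2
    return torch.stack([t[0] * t[1], t[1] + t[2] ** 2])

def f(u):                      # L = f(u): R^2 -> scalar
    return (u ** 2).sum()

u = g(theta)
L = f(u)

# dL/du_j, keeping the graph alive for the per-component grads below
dL_du = torch.autograd.grad(L, u, retain_graph=True)[0]

# Chain rule by hand: dL/dtheta_i = sum_j (dL/du_j) * (du_j/dtheta_i)
manual = torch.zeros_like(theta)
for j in range(u.numel()):
    du_j = torch.autograd.grad(u[j], theta, retain_graph=True)[0]
    manual = manual + dL_du[j] * du_j

auto = torch.autograd.grad(L, theta)[0]  # what autograd computes directly
print(torch.allclose(manual, auto))      # True
```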

Examples

Both are vectors

  • $\theta \in \mathbb{R}^n$ (parameter vector)
  • $u = g(\theta) \in \mathbb{R}^m$ (intermediate vector)
  • $L = f(u) \in \mathbb{R}$ (final scalar)

$$\frac{\partial L}{\partial \theta_i} = \sum_{j=1}^{m} \frac{\partial L}{\partial u_j}\,\frac{\partial u_j}{\partial \theta_i}$$

Sum over all $m$ components of the intermediate quantity $u$.
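
A minimal sketch of this case, assuming the hypothetical choices $u = A\theta$ and $L = \sum_j u_j^2$ (neither is prescribed above): the double loop writes the sum out literally, then collapses to the familiar vectorized form.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 3, 2
A = rng.standard_normal((m, n))
theta = rng.standard_normal(n)

u = A @ theta  # intermediate vector, u_j = sum_i A_ji * theta_i
# L = sum_j u_j^2  =>  dL/du_j = 2 u_j  and  du_j/dtheta_i = A_ji

# Component form: dL/dtheta_i = sum_j (dL/du_j)(du_j/dtheta_i)
grad = np.zeros(n)
for i in range(n):
    for j in range(m):
        grad[i] += (2 * u[j]) * A[j, i]

# The loop collapses to the vectorized form 2 * A.T @ u
print(np.allclose(grad, 2 * A.T @ u))  # True
```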

Parameter is matrix, intermediate is vector

  • $W \in \mathbb{R}^{m \times n}$ (parameter matrix)
  • $u = g(W) \in \mathbb{R}^{K}$ (intermediate vector)
  • $L = f(u) \in \mathbb{R}$ (final scalar)

$$\frac{\partial L}{\partial W_{ij}} = \sum_{k=1}^{K} \frac{\partial L}{\partial u_k}\,\frac{\partial u_k}{\partial W_{ij}}$$

Sum over all $K$ components of the intermediate quantity $u$.
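
A sketch of this case, again with hypothetical choices $u = Wx$ and $L = \sum_k u_k^2$: since $u_k$ only involves row $k$ of $W$, the Kronecker delta in $\partial u_k / \partial W_{ij}$ kills every term of the sum except $k = i$.

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 2, 3
W = rng.standard_normal((m, n))  # parameter matrix
x = rng.standard_normal(n)       # fixed input vector

u = W @ x  # intermediate vector, u_k = sum_j W_kj * x_j
# L = sum_k u_k^2  =>  dL/du_k = 2 u_k  and  du_k/dW_ij = (k == i) * x_j

# Component form: dL/dW_ij = sum_k (dL/du_k)(du_k/dW_ij)
grad = np.zeros((m, n))
for i in range(m):
    for j in range(n):
        for k in range(m):
            du_k_dW_ij = x[j] if k == i else 0.0
            grad[i, j] += (2 * u[k]) * du_k_dW_ij

# Only the k == i term survives, giving the outer product 2 * np.outer(u, x)
print(np.allclose(grad, 2 * np.outer(u, x)))  # True
```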

Parameter is matrix, intermediate is matrix

  • $W \in \mathbb{R}^{m \times n}$ (parameter matrix)
  • $U = g(W) \in \mathbb{R}^{p \times q}$ (intermediate matrix)
  • $L = f(U) \in \mathbb{R}$ (final scalar)

$$\frac{\partial L}{\partial W_{ij}} = \sum_{a=1}^{p} \sum_{b=1}^{q} \frac{\partial L}{\partial U_{ab}}\,\frac{\partial U_{ab}}{\partial W_{ij}}$$

Sum over all $p \times q$ components of the intermediate quantity $U$.
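
Same idea with a matrix intermediate, assuming the hypothetical choices $U = WX$ and $L = \sum_{a,b} U_{ab}^2$: the double sum over $(a, b)$ collapses because $U_{ab}$ only depends on row $a$ of $W$.

```python
import numpy as np

rng = np.random.default_rng(2)
m, n, q = 2, 3, 4
W = rng.standard_normal((m, n))  # parameter matrix
X = rng.standard_normal((n, q))  # fixed input matrix

U = W @ X  # intermediate matrix, U_ab = sum_c W_ac * X_cb
# L = sum_ab U_ab^2  =>  dL/dU_ab = 2 U_ab  and  dU_ab/dW_ij = (a == i) * X_jb

# Component form: dL/dW_ij = sum_a sum_b (dL/dU_ab)(dU_ab/dW_ij)
grad = np.zeros((m, n))
for i in range(m):
    for j in range(n):
        for a in range(m):
            for b in range(q):
                dU_ab_dW_ij = X[j, b] if a == i else 0.0
                grad[i, j] += (2 * U[a, b]) * dU_ab_dW_ij

# Only a == i survives, collapsing to 2 * U @ X.T
print(np.allclose(grad, 2 * U @ X.T))  # True
```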

How does the chain rule work here?

Suppose we have:

$$L = f(c), \qquad c = h(b), \qquad b = g(a), \qquad a = e(\theta)$$

Where:

  • $\theta$ is your parameter
  • $a = e(\theta)$ is the first intermediate quantity
  • $b = g(a)$ is the second intermediate quantity
  • $c = h(b)$ is the third intermediate quantity
  • $L = f(c)$ is the final scalar

We apply the chain rule repeatedly, but as separate summations, seeded by the direct derivative $\partial L / \partial c$:

Step 1: From $c$ to $b$

$$\frac{\partial L}{\partial b_k} = \sum_{l} \frac{\partial L}{\partial c_l}\,\frac{\partial c_l}{\partial b_k}$$

Step 2: From $b$ to $a$

$$\frac{\partial L}{\partial a_j} = \sum_{k} \frac{\partial L}{\partial b_k}\,\frac{\partial b_k}{\partial a_j}$$

Step 3: From $a$ to $\theta$

$$\frac{\partial L}{\partial \theta_i} = \sum_{j} \frac{\partial L}{\partial a_j}\,\frac{\partial a_j}{\partial \theta_i}$$

If you substitute Step 1 into Step 2 into Step 3:

$$\frac{\partial L}{\partial \theta_i} = \sum_{j}\sum_{k}\sum_{l} \frac{\partial L}{\partial c_l}\,\frac{\partial c_l}{\partial b_k}\,\frac{\partial b_k}{\partial a_j}\,\frac{\partial a_j}{\partial \theta_i}$$
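
To see that the three separate summations and the substituted triple sum agree, here is a sketch where every map is linear (a hypothetical setup, chosen so each Jacobian is just a matrix). Working backward one summation at a time is exactly what backpropagation does; the `einsum` line is the fully substituted expression.

```python
import numpy as np

rng = np.random.default_rng(3)
# Linear maps, so da/dtheta = E, db/da = G, dc/db = H (hypothetical setup)
E, G, H = (rng.standard_normal((4, 4)) for _ in range(3))
theta = rng.standard_normal(4)

a = E @ theta        # a = e(theta)
b = G @ a            # b = g(a)
c = H @ b            # c = h(b)
L = (c ** 2).sum()   # final scalar, so dL/dc = 2c

dL_dc = 2 * c            # seed: direct derivative of f
dL_db = H.T @ dL_dc      # Step 1: sum over l
dL_da = G.T @ dL_db      # Step 2: sum over k
dL_dtheta = E.T @ dL_da  # Step 3: sum over j

# The fully substituted triple sum, written out explicitly
brute = np.einsum("l,lk,kj,ji->i", 2 * c, H, G, E)
print(np.allclose(dL_dtheta, brute))  # True
```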

Of course, deriving gradients like this by hand is a nightmare, which is why we need an easier way to do it. For deep learning, we can follow a derive-as-you-go pattern. In computing, this is usually done with some graph structure like the PyTorch Computational Graph.
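
For example, in this minimal PyTorch sketch (the particular ops are arbitrary choices), each operation is recorded as a node in the computational graph as it runs, and `backward()` replays the chain rule through that graph so none of the summations above are derived by hand:

```python
import torch

theta = torch.randn(4, requires_grad=True)

# Each op below is recorded as a node in PyTorch's computational graph.
a = torch.tanh(theta)      # hypothetical e(theta)
b = a @ torch.randn(4, 4)  # hypothetical g(a)
c = torch.relu(b)          # hypothetical h(b)
L = (c ** 2).sum()         # final scalar f(c)

L.backward()       # walks the graph backward, applying the chain rule per node
print(theta.grad)  # dL/dtheta, with no hand derivation needed
```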

Once you get to a certain point of writing out the component representation of the gradients for the given functions, that’s when you hit a wall of “how do I actually simplify this?”. See The Art of Simplifying Multivariate Gradients