Multivariate Chain Rule
Given a composition of functions:

$$L = f(g(\theta))$$

Where:

- $\theta$ is your parameter (can be a scalar, vector, matrix, tensor - any shape)
- $z = g(\theta)$ is some intermediate function (can output a scalar, vector, matrix, tensor - any shape)
- $L = f(z)$ is the final output (can be a scalar, vector, matrix, tensor - any shape)
The universal multivariable chain rule is given as:

$$\frac{\partial L}{\partial \theta_i} = \sum_j \frac{\partial L}{\partial z_j}\,\frac{\partial z_j}{\partial \theta_i}$$

To translate: in order to determine how $L$ changes with an element $\theta_i$, we need to sum up the contributions of all the elements $z_j$ of $z$ that depend on $\theta_i$.
Why "sum over all indices of $z$"?

Because $z$ is the intermediate quantity that connects $\theta$ to $L$:

- Each element $z_j$ depends on $\theta_i$
- $L$ depends on each element $z_j$
- To find how $L$ depends on $\theta_i$, you add up the contributions from all the $z_j$
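
As a quick numerical sanity check of this summation, here is a minimal sketch assuming a hypothetical scalar parameter $\theta$, a 3-component intermediate $z = g(\theta)$, and $L = f(z) = \sum_j z_j^2$ (these specific functions are illustrative choices, not anything fixed by the text above):

```python
# Verify dL/dtheta = sum_j (dL/dz_j)(dz_j/dtheta) against a finite difference.
import numpy as np

def g(theta):
    # intermediate vector z with 3 components, each depending on theta
    return np.array([theta**2, np.sin(theta), 3.0 * theta])

def f(z):
    # final scalar L
    return np.sum(z**2)

theta = 0.7
z = g(theta)

# analytic pieces: dL/dz_j and dz_j/dtheta
dL_dz = 2.0 * z                                  # from L = sum(z_j^2)
dz_dtheta = np.array([2.0 * theta, np.cos(theta), 3.0])

# chain rule as a sum over all components j of the intermediate z
dL_dtheta_chain = np.sum(dL_dz * dz_dtheta)

# finite-difference check of dL/dtheta
eps = 1e-6
dL_dtheta_fd = (f(g(theta + eps)) - f(g(theta - eps))) / (2 * eps)

print(dL_dtheta_chain, dL_dtheta_fd)  # the two values should agree closely
```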
Examples
Both are vectors

- $\theta \in \mathbb{R}^n$ (parameter vector)
- $z = g(\theta) \in \mathbb{R}^m$ (intermediate vector)
- $L = f(z) \in \mathbb{R}$ (final scalar)

$$\frac{\partial L}{\partial \theta_i} = \sum_{j=1}^{m} \frac{\partial L}{\partial z_j}\,\frac{\partial z_j}{\partial \theta_i}$$

Sum over all components $z_j$ of the intermediate quantity $z$.
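
A minimal sketch of the vector case, using the hypothetical choices $z = A\theta$ and $L = \sum_j z_j^2$ so the Jacobian is easy to write down; the explicit double loop mirrors the summation above:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 3
A = rng.normal(size=(m, n))
theta = rng.normal(size=n)

z = A @ theta              # intermediate vector, z_j = sum_i A[j, i] * theta[i]
dL_dz = 2.0 * z            # from L = sum(z_j^2)
dz_dtheta = A              # Jacobian: dz_j / dtheta_i = A[j, i]

# component form: dL/dtheta_i = sum_j dL/dz_j * dz_j/dtheta_i
grad = np.zeros(n)
for i in range(n):
    for j in range(m):
        grad[i] += dL_dz[j] * dz_dtheta[j, i]

# vectorized equivalent of the same summation
print(np.allclose(grad, A.T @ dL_dz))  # True
```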
Parameter is matrix, intermediate is vector

- $W \in \mathbb{R}^{m \times n}$ (parameter matrix)
- $z = g(W) \in \mathbb{R}^{k}$ (intermediate vector)
- $L = f(z) \in \mathbb{R}$ (final scalar)

$$\frac{\partial L}{\partial W_{ab}} = \sum_{j=1}^{k} \frac{\partial L}{\partial z_j}\,\frac{\partial z_j}{\partial W_{ab}}$$

Sum over all components $z_j$ of the intermediate quantity $z$.
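
A sketch for the matrix-parameter case, assuming the hypothetical map $z = Wx$ (for a fixed input vector $x$) and $L = \sum_j z_j^2$; the inner loop over $j$ is exactly the sum in the formula above:

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 3, 4
W = rng.normal(size=(m, n))
x = rng.normal(size=n)

z = W @ x                  # z_j = sum_b W[j, b] * x[b]
dL_dz = 2.0 * z            # from L = sum(z_j^2)

# component form: dL/dW_ab = sum_j dL/dz_j * dz_j/dW_ab,
# and for z = W @ x we have dz_j/dW_ab = (1 if j == a else 0) * x[b]
grad_W = np.zeros((m, n))
for a in range(m):
    for b in range(n):
        for j in range(m):
            dz_j_dW_ab = x[b] if j == a else 0.0
            grad_W[a, b] += dL_dz[j] * dz_j_dW_ab

# here the sum collapses to the familiar outer-product form
print(np.allclose(grad_W, np.outer(dL_dz, x)))  # True
```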
Parameter is matrix, intermediate is matrix

- $W \in \mathbb{R}^{m \times n}$ (parameter matrix)
- $Z = g(W) \in \mathbb{R}^{p \times q}$ (intermediate matrix)
- $L = f(Z) \in \mathbb{R}$ (final scalar)

$$\frac{\partial L}{\partial W_{ab}} = \sum_{i=1}^{p} \sum_{j=1}^{q} \frac{\partial L}{\partial Z_{ij}}\,\frac{\partial Z_{ij}}{\partial W_{ab}}$$

Sum over all components $Z_{ij}$ of the intermediate quantity $Z$.
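
A sketch of the matrix-to-matrix case, again with hypothetical choices $Z = WB$ and $L = \sum_{i,j} Z_{ij}^2$, so the double sum over $i$ and $j$ can be written out explicitly:

```python
import numpy as np

rng = np.random.default_rng(2)
m, n, q = 3, 4, 2
W = rng.normal(size=(m, n))
B = rng.normal(size=(n, q))

Z = W @ B                  # Z[i, j] = sum_b W[i, b] * B[b, j]
dL_dZ = 2.0 * Z            # from L = sum(Z_ij^2)

# component form: dL/dW_ab = sum_i sum_j dL/dZ_ij * dZ_ij/dW_ab,
# and for Z = W @ B we have dZ_ij/dW_ab = (1 if i == a else 0) * B[b, j]
grad_W = np.zeros((m, n))
for a in range(m):
    for b in range(n):
        for i in range(m):
            for j in range(q):
                dZ_ij_dW_ab = B[b, j] if i == a else 0.0
                grad_W[a, b] += dL_dZ[i, j] * dZ_ij_dW_ab

# here the double sum collapses to dL/dW = (dL/dZ) @ B.T
print(np.allclose(grad_W, dL_dZ @ B.T))  # True
```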
How does the chain rule work here?
Suppose we have:

$$L = f\big(g_3(g_2(g_1(\theta)))\big)$$

Where:

- $\theta$ is your parameter
- $u = g_1(\theta)$ is the first intermediate quantity
- $v = g_2(u)$ is the second intermediate quantity
- $w = g_3(v)$ is the third intermediate quantity
- $L = f(w)$ is the final scalar
We apply the chain rule repeatedly, but as separate summations.

Step 1: From $w$ to $v$:

$$\frac{\partial L}{\partial v_j} = \sum_k \frac{\partial L}{\partial w_k}\,\frac{\partial w_k}{\partial v_j}$$

Step 2: From $v$ to $u$:

$$\frac{\partial L}{\partial u_i} = \sum_j \frac{\partial L}{\partial v_j}\,\frac{\partial v_j}{\partial u_i}$$

Step 3: From $u$ to $\theta$:

$$\frac{\partial L}{\partial \theta_p} = \sum_i \frac{\partial L}{\partial u_i}\,\frac{\partial u_i}{\partial \theta_p}$$

If you substitute Step 1 into Step 2 into Step 3:

$$\frac{\partial L}{\partial \theta_p} = \sum_i \sum_j \sum_k \frac{\partial L}{\partial w_k}\,\frac{\partial w_k}{\partial v_j}\,\frac{\partial v_j}{\partial u_i}\,\frac{\partial u_i}{\partial \theta_p}$$
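
As a sketch, the following compares the step-by-step backward pass (the three separate summations above) with the fully substituted nested sum, using hypothetical linear maps for $g_1$, $g_2$, $g_3$ and $L = \sum_k w_k^2$ so every Jacobian is just a matrix:

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.normal(size=(3, 4))   # Jacobian du/dtheta
B = rng.normal(size=(5, 3))   # Jacobian dv/du
C = rng.normal(size=(2, 5))   # Jacobian dw/dv
theta = rng.normal(size=4)

u = A @ theta
v = B @ u
w = C @ v
dL_dw = 2.0 * w               # from L = sum(w_k^2)

# step-by-step: propagate the gradient one intermediate at a time
dL_dv = C.T @ dL_dw           # Step 1: sum over k
dL_du = B.T @ dL_dv           # Step 2: sum over j
dL_dtheta_steps = A.T @ dL_du # Step 3: sum over i

# fully substituted form: one big nested summation over i, j, k
dL_dtheta_full = np.zeros(4)
for p in range(4):
    for i in range(3):
        for j in range(5):
            for k in range(2):
                dL_dtheta_full[p] += dL_dw[k] * C[k, j] * B[j, i] * A[i, p]

print(np.allclose(dL_dtheta_steps, dL_dtheta_full))  # True
```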
Of course, deriving like this is a nightmare, which is why we need an easier way to do it. For deep learning, we can follow a derive-as-you-go pattern. In practice, this is usually done with a graph structure like the PyTorch computational graph.
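
A minimal PyTorch sketch of the derive-as-you-go idea: autograd records each intermediate in its computational graph during the forward pass and applies the per-node summations during `backward()`. The specific functions below are arbitrary illustrative choices:

```python
import torch

theta = torch.randn(4, requires_grad=True)
A = torch.randn(3, 4)

u = A @ theta          # first intermediate
v = torch.tanh(u)      # second intermediate
w = v * v              # third intermediate
L = w.sum()            # final scalar

L.backward()           # walks the recorded graph, accumulating one sum per node
print(theta.grad)      # dL/dtheta, with no manual index bookkeeping
```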
Once you start writing out the component representation of the gradients for the given functions, you quickly hit a wall of "how do I actually simplify this?". See The Art of Simplifying Multivariate Gradients.
