$$\def\mymacro{{\mathbf{\alpha,\beta,\gamma}}}$$
$$\def\va{{\mathbf{a}}}$$
$$\def\vb{{\mathbf{b}}}$$
$$\def\vc{{\mathbf{c}}}$$
$$\def\vd{{\mathbf{d}}}$$
$$\def\ve{{\mathbf{e}}}$$
$$\def\vf{{\mathbf{f}}}$$
$$\def\vg{{\mathbf{g}}}$$
$$\def\vh{{\mathbf{h}}}$$
$$\def\vi{{\mathbf{i}}}$$
$$\def\vj{{\mathbf{j}}}$$
$$\def\vk{{\mathbf{k}}}$$
$$\def\vl{{\mathbf{l}}}$$
$$\def\vm{{\mathbf{m}}}$$
$$\def\vn{{\mathbf{n}}}$$
$$\def\vo{{\mathbf{o}}}$$
$$\def\vp{{\mathbf{p}}}$$
$$\def\vq{{\mathbf{q}}}$$
$$\def\vr{{\mathbf{r}}}$$
$$\def\vs{{\mathbf{s}}}$$
$$\def\vt{{\mathbf{t}}}$$
$$\def\vu{{\mathbf{u}}}$$
$$\def\vv{{\mathbf{v}}}$$
$$\def\vw{{\mathbf{w}}}$$
$$\def\vx{{\mathbf{x}}}$$
$$\def\vy{{\mathbf{y}}}$$
$$\def\vz{{\mathbf{z}}}$$
$$\def\vmu{{\mathbf{\mu}}}$$
$$\def\vtheta{{\mathbf{\theta}}}$$
$$\def\vzero{{\mathbf{0}}}$$
$$\def\vone{{\mathbf{1}}}$$
$$\def\vell{{\mathbf{\ell}}}$$
$$\def\mA{{\mathbf{A}}}$$
$$\def\mB{{\mathbf{B}}}$$
$$\def\mC{{\mathbf{C}}}$$
$$\def\mD{{\mathbf{D}}}$$
$$\def\mE{{\mathbf{E}}}$$
$$\def\mF{{\mathbf{F}}}$$
$$\def\mG{{\mathbf{G}}}$$
$$\def\mH{{\mathbf{H}}}$$
$$\def\mI{{\mathbf{I}}}$$
$$\def\mJ{{\mathbf{J}}}$$
$$\def\mK{{\mathbf{K}}}$$
$$\def\mL{{\mathbf{L}}}$$
$$\def\mM{{\mathbf{M}}}$$
$$\def\mN{{\mathbf{N}}}$$
$$\def\mO{{\mathbf{O}}}$$
$$\def\mP{{\mathbf{P}}}$$
$$\def\mQ{{\mathbf{Q}}}$$
$$\def\mR{{\mathbf{R}}}$$
$$\def\mS{{\mathbf{S}}}$$
$$\def\mT{{\mathbf{T}}}$$
$$\def\mU{{\mathbf{U}}}$$
$$\def\mV{{\mathbf{V}}}$$
$$\def\mW{{\mathbf{W}}}$$
$$\def\mX{{\mathbf{X}}}$$
$$\def\mY{{\mathbf{Y}}}$$
$$\def\mZ{{\mathbf{Z}}}$$
$$\def\mStilde{\mathbf{\tilde{\mS}}}$$
$$\def\mGtilde{\mathbf{\tilde{\mG}}}$$
$$\def\mGoverline{{\mathbf{\overline{G}}}}$$
$$\def\mBeta{{\mathbf{\beta}}}$$
$$\def\mPhi{{\mathbf{\Phi}}}$$
$$\def\mLambda{{\mathbf{\Lambda}}}$$
$$\def\mSigma{{\mathbf{\Sigma}}}$$
$$\def\gA{{\mathcal{A}}}$$
$$\def\gB{{\mathcal{B}}}$$
$$\def\gC{{\mathcal{C}}}$$
$$\def\gD{{\mathcal{D}}}$$
$$\def\gE{{\mathcal{E}}}$$
$$\def\gF{{\mathcal{F}}}$$
$$\def\gG{{\mathcal{G}}}$$
$$\def\gH{{\mathcal{H}}}$$
$$\def\gI{{\mathcal{I}}}$$
$$\def\gJ{{\mathcal{J}}}$$
$$\def\gK{{\mathcal{K}}}$$
$$\def\gL{{\mathcal{L}}}$$
$$\def\gM{{\mathcal{M}}}$$
$$\def\gN{{\mathcal{N}}}$$
$$\def\gO{{\mathcal{O}}}$$
$$\def\gP{{\mathcal{P}}}$$
$$\def\gQ{{\mathcal{Q}}}$$
$$\def\gR{{\mathcal{R}}}$$
$$\def\gS{{\mathcal{S}}}$$
$$\def\gT{{\mathcal{T}}}$$
$$\def\gU{{\mathcal{U}}}$$
$$\def\gV{{\mathcal{V}}}$$
$$\def\gW{{\mathcal{W}}}$$
$$\def\gX{{\mathcal{X}}}$$
$$\def\gY{{\mathcal{Y}}}$$
$$\def\gZ{{\mathcal{Z}}}$$
$$\def\sA{{\mathbb{A}}}$$
$$\def\sB{{\mathbb{B}}}$$
$$\def\sC{{\mathbb{C}}}$$
$$\def\sD{{\mathbb{D}}}$$
$$\def\sF{{\mathbb{F}}}$$
$$\def\sG{{\mathbb{G}}}$$
$$\def\sH{{\mathbb{H}}}$$
$$\def\sI{{\mathbb{I}}}$$
$$\def\sJ{{\mathbb{J}}}$$
$$\def\sK{{\mathbb{K}}}$$
$$\def\sL{{\mathbb{L}}}$$
$$\def\sM{{\mathbb{M}}}$$
$$\def\sN{{\mathbb{N}}}$$
$$\def\sO{{\mathbb{O}}}$$
$$\def\sP{{\mathbb{P}}}$$
$$\def\sQ{{\mathbb{Q}}}$$
$$\def\sR{{\mathbb{R}}}$$
$$\def\sS{{\mathbb{S}}}$$
$$\def\sT{{\mathbb{T}}}$$
$$\def\sU{{\mathbb{U}}}$$
$$\def\sV{{\mathbb{V}}}$$
$$\def\sW{{\mathbb{W}}}$$
$$\def\sX{{\mathbb{X}}}$$
$$\def\sY{{\mathbb{Y}}}$$
$$\def\sZ{{\mathbb{Z}}}$$
$$\def\E{{\mathbb{E}}}$$
$$\def\jac{{\mathbf{\mathrm{J}}}}$$
$$\def\argmax{{\mathop{\mathrm{arg}\,\mathrm{max}}}}$$
$$\def\argmin{{\mathop{\mathrm{arg}\,\mathrm{min}}}}$$
$$\def\Tr{{\mathop{\mathrm{Tr}}}}$$
$$\def\diag{{\mathop{\mathrm{diag}}}}$$
$$\def\vec{{\mathop{\mathrm{vec}}}}$$

# Hessian row sum in PyTorch

You can grab the code as a Python script here.

A friend recently asked me how to compute the row-wise sum of the Hessian matrix in PyTorch. This can be done with a single Hessian-vector product and is thus a neat use case for that functionality.

I will first demonstrate Hessian-vector products in PyTorch on a simple example before moving on to a more popular example with neural networks.

Let's get the imports out of the way.

```python
from typing import List

from torch import Tensor, ones_like, allclose, cat, rand, manual_seed
from torch.nn import Linear, MSELoss, ReLU, Sequential

from backpack.hessianfree.hvp import hessian_vector_product
```


## Simple example

As a sanity check, let's quickly walk through a low-dimensional example.

Consider the scalar function

\begin{equation*} f(x_1, x_2) = x_{1}^{2} + 2 x_{1}^{2} x_{2} - 3 x_{2}^{3}\,, \qquad x_{1,2}\in\sR\,. \end{equation*}

Its Hessian $$\mH(x_{1}, x_{2})$$ is

\begin{equation*} \mH(x_1, x_2) = \begin{pmatrix} \frac{\partial^2 f}{\partial x_1^2} & \frac{\partial^2 f}{\partial x_1 \partial x_2} \\ \frac{\partial^2 f}{\partial x_2 \partial x_1} & \frac{\partial^2 f}{\partial x_2^2} \end{pmatrix} = \begin{pmatrix} 2 + 4x_{2} & 4x_1 \\ 4x_1 & -18x_2 \end{pmatrix} \end{equation*}

and the row-wise sum results from multiplication with a vector of ones,

\begin{equation*} \begin{pmatrix} 2 + 4x_{2} + 4x_1 \\ 4x_1 - 18x_2 \end{pmatrix} = \mH(x_1, x_2) \begin{pmatrix} 1 \\ 1 \end{pmatrix}\,. \end{equation*}

If we choose $$(x_{1} = 1, x_{2} = 2)$$, then the result should be $$\begin{pmatrix} 14 \\ -32 \end{pmatrix}$$. Let's reproduce that in code:

```python
x1 = Tensor([1.0]).requires_grad_()
x2 = Tensor([2.0]).requires_grad_()

expected = Tensor([14.0, -32.0])

f = x1 ** 2 + 2 * x1 ** 2 * x2 - 3 * x2 ** 3
```


The Hessian-vector product expects the vector in list format, with one entry shaped like each of f's variables:

```python
parameters = [x1, x2]
v_list = [ones_like(param) for param in parameters]
Hv_list = hessian_vector_product(f, parameters, v_list)

print(Hv_list)
```

(tensor([14.]), tensor([-32.]))


The output is also in list format, which we flatten and concatenate before comparing:

```python
Hv = cat([v.flatten() for v in Hv_list])

if allclose(Hv, expected):
    print("Results match.")
else:
    print("Results don't match.")
```

Results match.
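
As an independent sanity check (not part of the original script), we can also build the full Hessian with `torch.autograd.functional.hessian` and sum its rows explicitly. This brute-force approach only pays off in low dimensions, but it confirms the numbers:

```python
import torch
from torch.autograd.functional import hessian


def f(x1, x2):
    # Same function as above
    return x1 ** 2 + 2 * x1 ** 2 * x2 - 3 * x2 ** 3


# For tuple inputs, hessian returns a nested tuple H[i][j] = d^2 f / (dx_i dx_j)
H = hessian(f, (torch.tensor(1.0), torch.tensor(2.0)))
H_mat = torch.stack([torch.stack(row) for row in H])

row_sum = H_mat @ torch.ones(2)
print(row_sum)  # row sums 14 and -32, as before
```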


## Neural network example

The approach above works for arbitrary computation graphs. A more common use case, however, is the Hessian of a neural network's loss with respect to the network's parameters.

To avoid repeating myself, I'll condense everything into a utility function that you can copy:

```python
def hessian_row_sum(f: Tensor, parameters: List[Tensor]) -> Tensor:
    """Compute the Hessian row-wise sum and return it as a 1d tensor.

    Args:
        f: Function of the Hessian.
        parameters: Variables of f. Need to have requires_grad=True.

    Returns:
        Hessian row sum as vector.
    """
    Hv_list = hessian_vector_product(
        f, parameters, [ones_like(param) for param in parameters]
    )

    return cat([v.flatten() for v in Hv_list])
```
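
For reference, the same utility can be sketched with plain `torch.autograd` and no extra dependency, using double backpropagation (contract the gradient with a vector of ones, then differentiate again). This variant is my own addition, not part of the original post:

```python
from typing import List

import torch
from torch import Tensor


def hessian_row_sum_autograd(f: Tensor, parameters: List[Tensor]) -> Tensor:
    """Hessian row sum via double backpropagation (no BackPACK required)."""
    # First backward pass, keeping the graph for a second differentiation
    grads = torch.autograd.grad(f, parameters, create_graph=True)
    # Contract the gradient with a vector of ones, then differentiate again
    ones = [torch.ones_like(p) for p in parameters]
    Hv = torch.autograd.grad(grads, parameters, grad_outputs=ones)
    return torch.cat([h.flatten() for h in Hv])


# Quick check on f(x) = sum(x_i^3): the Hessian is diag(6 x), so the row sum is 6 x
x = torch.arange(3.0, requires_grad=True)
f = (x ** 3).sum()
print(hessian_row_sum_autograd(f, [x]))  # 6 * [0, 1, 2] = [0, 6, 12]
```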


Here's a simple neural net with random data:

```python
manual_seed(0)

batch_size = 3
D_in = 5
D_out = 4

X = rand(batch_size, D_in)
y = rand(batch_size, D_out)

model = Sequential(Linear(D_in, D_out, bias=False), ReLU())
loss_func = MSELoss()

output = model(X)
loss = loss_func(output, y)
```


Let's call our utility:

```python
result = hessian_row_sum(loss, list(model.parameters()))

print(result.shape)
```

torch.Size([20])
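
The shape matches the number of parameters: the bias-free `Linear(5, 4)` layer has a 4×5 weight, i.e. 20 entries. To double-check, one can brute-force the full Hessian of the same loss with `torch.autograd.functional.hessian`. This sketch re-expresses the single-layer, bias-free model functionally (my own addition, assuming the setup from above):

```python
import torch
from torch.autograd.functional import hessian

torch.manual_seed(0)

X = torch.rand(3, 5)
y = torch.rand(3, 4)
W = torch.rand(4, 5)  # stands in for the Linear layer's weight


def loss_fn(weight):
    # Functional version of Sequential(Linear(5, 4, bias=False), ReLU()) + MSELoss
    return torch.nn.functional.mse_loss(torch.relu(X @ weight.T), y)


H = hessian(loss_fn, W)  # shape (4, 5, 4, 5)
row_sum = H.reshape(20, 20) @ torch.ones(20)
print(row_sum.shape)  # torch.Size([20])
```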


Enjoy!

Created: 2021-09-13 Mon 16:10
