\(\def\va{{\mathbf{a}}}\)
\(\def\vb{{\mathbf{b}}}\)
\(\def\vc{{\mathbf{c}}}\)
\(\def\vd{{\mathbf{d}}}\)
\(\def\ve{{\mathbf{e}}}\)
\(\def\vf{{\mathbf{f}}}\)
\(\def\vg{{\mathbf{g}}}\)
\(\def\vh{{\mathbf{h}}}\)
\(\def\vi{{\mathbf{i}}}\)
\(\def\vj{{\mathbf{j}}}\)
\(\def\vk{{\mathbf{k}}}\)
\(\def\vl{{\mathbf{l}}}\)
\(\def\vm{{\mathbf{m}}}\)
\(\def\vn{{\mathbf{n}}}\)
\(\def\vo{{\mathbf{o}}}\)
\(\def\vp{{\mathbf{p}}}\)
\(\def\vq{{\mathbf{q}}}\)
\(\def\vr{{\mathbf{r}}}\)
\(\def\vs{{\mathbf{s}}}\)
\(\def\vt{{\mathbf{t}}}\)
\(\def\vu{{\mathbf{u}}}\)
\(\def\vv{{\mathbf{v}}}\)
\(\def\vw{{\mathbf{w}}}\)
\(\def\vx{{\mathbf{x}}}\)
\(\def\vy{{\mathbf{y}}}\)
\(\def\vz{{\mathbf{z}}}\)
\(\def\vmu{{\mathbf{\mu}}}\)
\(\def\vtheta{{\mathbf{\theta}}}\)
\(\def\vzero{{\mathbf{0}}}\)
\(\def\vone{{\mathbf{1}}}\)
\(\def\vell{{\mathbf{\ell}}}\)
\(\def\mA{{\mathbf{A}}}\)
\(\def\mB{{\mathbf{B}}}\)
\(\def\mC{{\mathbf{C}}}\)
\(\def\mD{{\mathbf{D}}}\)
\(\def\mE{{\mathbf{E}}}\)
\(\def\mF{{\mathbf{F}}}\)
\(\def\mG{{\mathbf{G}}}\)
\(\def\mH{{\mathbf{H}}}\)
\(\def\mI{{\mathbf{I}}}\)
\(\def\mJ{{\mathbf{J}}}\)
\(\def\mK{{\mathbf{K}}}\)
\(\def\mL{{\mathbf{L}}}\)
\(\def\mM{{\mathbf{M}}}\)
\(\def\mN{{\mathbf{N}}}\)
\(\def\mO{{\mathbf{O}}}\)
\(\def\mP{{\mathbf{P}}}\)
\(\def\mQ{{\mathbf{Q}}}\)
\(\def\mR{{\mathbf{R}}}\)
\(\def\mS{{\mathbf{S}}}\)
\(\def\mT{{\mathbf{T}}}\)
\(\def\mU{{\mathbf{U}}}\)
\(\def\mV{{\mathbf{V}}}\)
\(\def\mW{{\mathbf{W}}}\)
\(\def\mX{{\mathbf{X}}}\)
\(\def\mY{{\mathbf{Y}}}\)
\(\def\mZ{{\mathbf{Z}}}\)
\(\def\mStilde{\mathbf{\tilde{\mS}}}\)
\(\def\mGtilde{\mathbf{\tilde{\mG}}}\)
\(\def\mGoverline{{\mathbf{\overline{G}}}}\)
\(\def\mBeta{{\mathbf{\beta}}}\)
\(\def\mPhi{{\mathbf{\Phi}}}\)
\(\def\mLambda{{\mathbf{\Lambda}}}\)
\(\def\mSigma{{\mathbf{\Sigma}}}\)
\(\def\gA{{\mathcal{A}}}\)
\(\def\gB{{\mathcal{B}}}\)
\(\def\gC{{\mathcal{C}}}\)
\(\def\gD{{\mathcal{D}}}\)
\(\def\gE{{\mathcal{E}}}\)
\(\def\gF{{\mathcal{F}}}\)
\(\def\gG{{\mathcal{G}}}\)
\(\def\gH{{\mathcal{H}}}\)
\(\def\gI{{\mathcal{I}}}\)
\(\def\gJ{{\mathcal{J}}}\)
\(\def\gK{{\mathcal{K}}}\)
\(\def\gL{{\mathcal{L}}}\)
\(\def\gM{{\mathcal{M}}}\)
\(\def\gN{{\mathcal{N}}}\)
\(\def\gO{{\mathcal{O}}}\)
\(\def\gP{{\mathcal{P}}}\)
\(\def\gQ{{\mathcal{Q}}}\)
\(\def\gR{{\mathcal{R}}}\)
\(\def\gS{{\mathcal{S}}}\)
\(\def\gT{{\mathcal{T}}}\)
\(\def\gU{{\mathcal{U}}}\)
\(\def\gV{{\mathcal{V}}}\)
\(\def\gW{{\mathcal{W}}}\)
\(\def\gX{{\mathcal{X}}}\)
\(\def\gY{{\mathcal{Y}}}\)
\(\def\gZ{{\mathcal{Z}}}\)
\(\def\sA{{\mathbb{A}}}\)
\(\def\sB{{\mathbb{B}}}\)
\(\def\sC{{\mathbb{C}}}\)
\(\def\sD{{\mathbb{D}}}\)
\(\def\sF{{\mathbb{F}}}\)
\(\def\sG{{\mathbb{G}}}\)
\(\def\sH{{\mathbb{H}}}\)
\(\def\sI{{\mathbb{I}}}\)
\(\def\sJ{{\mathbb{J}}}\)
\(\def\sK{{\mathbb{K}}}\)
\(\def\sL{{\mathbb{L}}}\)
\(\def\sM{{\mathbb{M}}}\)
\(\def\sN{{\mathbb{N}}}\)
\(\def\sO{{\mathbb{O}}}\)
\(\def\sP{{\mathbb{P}}}\)
\(\def\sQ{{\mathbb{Q}}}\)
\(\def\sR{{\mathbb{R}}}\)
\(\def\sS{{\mathbb{S}}}\)
\(\def\sT{{\mathbb{T}}}\)
\(\def\sU{{\mathbb{U}}}\)
\(\def\sV{{\mathbb{V}}}\)
\(\def\sW{{\mathbb{W}}}\)
\(\def\sX{{\mathbb{X}}}\)
\(\def\sY{{\mathbb{Y}}}\)
\(\def\sZ{{\mathbb{Z}}}\)
\(\def\E{{\mathbb{E}}}\)
\(\def\jac{{\mathbf{\mathrm{J}}}}\)
\(\def\argmax{{\mathop{\mathrm{arg}\,\mathrm{max}}}}\)
\(\def\argmin{{\mathop{\mathrm{arg}\,\mathrm{min}}}}\)
\(\def\Tr{{\mathop{\mathrm{Tr}}}}\)
\(\def\diag{{\mathop{\mathrm{diag}}}}\)
\(\def\vec{{\mathop{\mathrm{vec}}}}\)

Hessian row sum in PyTorch

You can grab the code as a Python script here.

A friend recently asked me how to compute the row-wise sum of the Hessian matrix in PyTorch. This can be done with a single Hessian-vector product and is thus a neat use case for that functionality.
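
To see why a single product suffices, note that multiplying a matrix with the all-ones vector produces its row sums,

\begin{equation*} \left( \mH \vone \right)_{i} = \sum_{j} \left[ \mH \right]_{ij}\,. \end{equation*}

Hence, one Hessian-vector product with \(\vv = \vone\) yields all row sums at once, without ever materializing \(\mH\).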

I will first demonstrate Hessian-vector products in PyTorch on a simple example before moving on to the more common setting of a neural network's loss.

Let's get the imports out of the way.

from typing import List

from backpack.hessianfree.hvp import hessian_vector_product
from torch import Tensor, allclose, cat, manual_seed, ones_like, rand
from torch.nn import Linear, MSELoss, ReLU, Sequential
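
The hessian_vector_product helper comes from the BackPACK library, which can be installed from PyPI:

pip install backpack-for-pytorch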

Simple example

As a sanity check, let's quickly walk through a low-dimensional example.

Consider the scalar function

\begin{equation*} f(x_1, x_2) = x_{1}^{2} + 2 x_{1}^{2} x_{2} - 3 x_{2}^{3}\,, \qquad x_{1,2}\in\sR\,. \end{equation*}

Its Hessian \(\mH(x_{1}, x_{2})\) is

\begin{equation*} \mH(x_1, x_2) = \begin{pmatrix} \frac{\partial^2 f}{\partial x_1^2} & \frac{\partial^2 f}{\partial x_1 \partial x_2} \\ \frac{\partial^2 f}{\partial x_2 \partial x_1} & \frac{\partial^2 f}{\partial x_2^2} \end{pmatrix} = \begin{pmatrix} 2 + 4x_{2} & 4x_1 \\ 4x_1 & -18x_2 \end{pmatrix} \end{equation*}

and the row-wise sum results from a multiplication with a vector of ones,

\begin{equation*} \begin{pmatrix} 2 + 4x_{2} + 4x_1 \\ 4x_1 - 18x_2 \end{pmatrix} = \mH(x_1, x_2) \begin{pmatrix} 1 \\ 1 \end{pmatrix}\,. \end{equation*}

If we choose \((x_{1} = 1, x_{2} = 2)\), then the result should be \(\begin{pmatrix} 14 \\ -32 \end{pmatrix}\). Let's reproduce that in code:

x1 = Tensor([1.0])
x2 = Tensor([2.0])

# differentiation requires tensors that track gradients
x1.requires_grad = True
x2.requires_grad = True

expected = Tensor([14.0, -32.0])

# build the computation graph of f
f = x1 ** 2 + 2 * x1 ** 2 * x2 - 3 * x2 ** 3

The Hessian-vector product takes the vector in list format, with one entry per variable of f (matching that variable's shape):

parameters = [x1, x2]
v_list = [ones_like(param) for param in parameters]
Hv_list = hessian_vector_product(f, parameters, v_list)

print(Hv_list)
(tensor([14.]), tensor([-32.]))

The output is also in list format, which we need to flatten and concatenate before comparing:

Hv = cat([v.flatten() for v in Hv_list])

if allclose(Hv, expected):
    print("Results match.")
else:
    print("Results don't match.")
Results match.
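
For the curious: such a Hessian-vector product boils down to a double backward pass, \(\mH \vv = \nabla \left( \left( \nabla f \right)^{\top} \vv \right)\), so the Hessian itself is never built. Here is a minimal sketch with plain torch.autograd that reproduces the result above (BackPACK's helper is more general, but this conveys the idea):

from torch import autograd

# first backward pass: gradient with a graph for differentiating again
grads = autograd.grad(f, parameters, create_graph=True)
# contract the gradient with v, then take a second backward pass
grad_dot_v = sum((g * v).sum() for g, v in zip(grads, v_list))
Hv_manual = autograd.grad(grad_dot_v, parameters)

print(Hv_manual)
(tensor([14.]), tensor([-32.]))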

Neural network example

The above approach works for arbitrary computation graphs. A very common use case, however, is the Hessian of a neural network's loss function.

To avoid repeating myself, I'll condense the above steps into a utility function that you can copy:

def hessian_row_sum(f: Tensor, parameters: List[Tensor]) -> Tensor:
    """Compute the Hessian's row-wise sum and return it as a 1d tensor.

    Args:
        f: Scalar tensor whose Hessian is considered.
        parameters: Variables of ``f``. Need to have ``requires_grad=True``.

    Returns:
        Hessian row sum as a vector.
    """
    # multiply the Hessian onto a vector of ones ...
    Hv_list = hessian_vector_product(
        f, parameters, [ones_like(param) for param in parameters]
    )

    # ... then flatten and concatenate the result into a vector
    return cat([v.flatten() for v in Hv_list])

Here's a simple neural net with random data:

manual_seed(0)  # seed for reproducible random data

batch_size = 3
D_in = 5
D_out = 4

X = rand(batch_size, D_in)
y = rand(batch_size, D_out)

model = Sequential(Linear(D_in, D_out, bias=False), ReLU())
loss_func = MSELoss()

output = model(X)
loss = loss_func(output, y)

Let's call our utility:

result = hessian_row_sum(loss, list(model.parameters()))

print(result.shape)
torch.Size([20])
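
The 20 entries correspond to the D_in * D_out = 20 entries of the linear layer's weight matrix. As a sanity check, we can compare against the row sums of the explicitly computed Hessian. The sketch below is specific to the model above (a single bias-free linear layer followed by ReLU): it rebuilds the loss as a function of the flattened weight so that torch.autograd.functional.hessian applies:

from torch.autograd.functional import hessian as exact_hessian

W = next(model.parameters())

def loss_fn(w_flat: Tensor) -> Tensor:
    # re-implement the forward pass of the model above for a flat weight
    output = (X @ w_flat.reshape(D_out, D_in).T).relu()
    return loss_func(output, y)

H = exact_hessian(loss_fn, W.detach().flatten())
print(allclose(H.sum(1), result))
True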

Enjoy!

Author: Felix Dangel

Created: 2021-09-13 Mon 16:10
