$$\def\mymacro{{\mathbf{\alpha,\beta,\gamma}}}$$
$$\def\va{{\mathbf{a}}}$$
$$\def\vb{{\mathbf{b}}}$$
$$\def\vc{{\mathbf{c}}}$$
$$\def\vd{{\mathbf{d}}}$$
$$\def\ve{{\mathbf{e}}}$$
$$\def\vf{{\mathbf{f}}}$$
$$\def\vg{{\mathbf{g}}}$$
$$\def\vh{{\mathbf{h}}}$$
$$\def\vi{{\mathbf{i}}}$$
$$\def\vj{{\mathbf{j}}}$$
$$\def\vk{{\mathbf{k}}}$$
$$\def\vl{{\mathbf{l}}}$$
$$\def\vm{{\mathbf{m}}}$$
$$\def\vn{{\mathbf{n}}}$$
$$\def\vo{{\mathbf{o}}}$$
$$\def\vp{{\mathbf{p}}}$$
$$\def\vq{{\mathbf{q}}}$$
$$\def\vr{{\mathbf{r}}}$$
$$\def\vs{{\mathbf{s}}}$$
$$\def\vt{{\mathbf{t}}}$$
$$\def\vu{{\mathbf{u}}}$$
$$\def\vv{{\mathbf{v}}}$$
$$\def\vw{{\mathbf{w}}}$$
$$\def\vx{{\mathbf{x}}}$$
$$\def\vy{{\mathbf{y}}}$$
$$\def\vz{{\mathbf{z}}}$$
$$\def\vmu{{\mathbf{\mu}}}$$
$$\def\vtheta{{\mathbf{\theta}}}$$
$$\def\vzero{{\mathbf{0}}}$$
$$\def\vone{{\mathbf{1}}}$$
$$\def\vell{{\mathbf{\ell}}}$$
$$\def\mA{{\mathbf{A}}}$$
$$\def\mB{{\mathbf{B}}}$$
$$\def\mC{{\mathbf{C}}}$$
$$\def\mD{{\mathbf{D}}}$$
$$\def\mE{{\mathbf{E}}}$$
$$\def\mF{{\mathbf{F}}}$$
$$\def\mG{{\mathbf{G}}}$$
$$\def\mH{{\mathbf{H}}}$$
$$\def\mI{{\mathbf{I}}}$$
$$\def\mJ{{\mathbf{J}}}$$
$$\def\mK{{\mathbf{K}}}$$
$$\def\mL{{\mathbf{L}}}$$
$$\def\mM{{\mathbf{M}}}$$
$$\def\mN{{\mathbf{N}}}$$
$$\def\mO{{\mathbf{O}}}$$
$$\def\mP{{\mathbf{P}}}$$
$$\def\mQ{{\mathbf{Q}}}$$
$$\def\mR{{\mathbf{R}}}$$
$$\def\mS{{\mathbf{S}}}$$
$$\def\mT{{\mathbf{T}}}$$
$$\def\mU{{\mathbf{U}}}$$
$$\def\mV{{\mathbf{V}}}$$
$$\def\mW{{\mathbf{W}}}$$
$$\def\mX{{\mathbf{X}}}$$
$$\def\mY{{\mathbf{Y}}}$$
$$\def\mZ{{\mathbf{Z}}}$$
$$\def\mStilde{\mathbf{\tilde{\mS}}}$$
$$\def\mGtilde{\mathbf{\tilde{\mG}}}$$
$$\def\mGoverline{{\mathbf{\overline{G}}}}$$
$$\def\mBeta{{\mathbf{\beta}}}$$
$$\def\mPhi{{\mathbf{\Phi}}}$$
$$\def\mLambda{{\mathbf{\Lambda}}}$$
$$\def\mSigma{{\mathbf{\Sigma}}}$$
$$\def\gA{{\mathcal{A}}}$$
$$\def\gB{{\mathcal{B}}}$$
$$\def\gC{{\mathcal{C}}}$$
$$\def\gD{{\mathcal{D}}}$$
$$\def\gE{{\mathcal{E}}}$$
$$\def\gF{{\mathcal{F}}}$$
$$\def\gG{{\mathcal{G}}}$$
$$\def\gH{{\mathcal{H}}}$$
$$\def\gI{{\mathcal{I}}}$$
$$\def\gJ{{\mathcal{J}}}$$
$$\def\gK{{\mathcal{K}}}$$
$$\def\gL{{\mathcal{L}}}$$
$$\def\gM{{\mathcal{M}}}$$
$$\def\gN{{\mathcal{N}}}$$
$$\def\gO{{\mathcal{O}}}$$
$$\def\gP{{\mathcal{P}}}$$
$$\def\gQ{{\mathcal{Q}}}$$
$$\def\gR{{\mathcal{R}}}$$
$$\def\gS{{\mathcal{S}}}$$
$$\def\gT{{\mathcal{T}}}$$
$$\def\gU{{\mathcal{U}}}$$
$$\def\gV{{\mathcal{V}}}$$
$$\def\gW{{\mathcal{W}}}$$
$$\def\gX{{\mathcal{X}}}$$
$$\def\gY{{\mathcal{Y}}}$$
$$\def\gZ{{\mathcal{Z}}}$$
$$\def\sA{{\mathbb{A}}}$$
$$\def\sB{{\mathbb{B}}}$$
$$\def\sC{{\mathbb{C}}}$$
$$\def\sD{{\mathbb{D}}}$$
$$\def\sF{{\mathbb{F}}}$$
$$\def\sG{{\mathbb{G}}}$$
$$\def\sH{{\mathbb{H}}}$$
$$\def\sI{{\mathbb{I}}}$$
$$\def\sJ{{\mathbb{J}}}$$
$$\def\sK{{\mathbb{K}}}$$
$$\def\sL{{\mathbb{L}}}$$
$$\def\sM{{\mathbb{M}}}$$
$$\def\sN{{\mathbb{N}}}$$
$$\def\sO{{\mathbb{O}}}$$
$$\def\sP{{\mathbb{P}}}$$
$$\def\sQ{{\mathbb{Q}}}$$
$$\def\sR{{\mathbb{R}}}$$
$$\def\sS{{\mathbb{S}}}$$
$$\def\sT{{\mathbb{T}}}$$
$$\def\sU{{\mathbb{U}}}$$
$$\def\sV{{\mathbb{V}}}$$
$$\def\sW{{\mathbb{W}}}$$
$$\def\sX{{\mathbb{X}}}$$
$$\def\sY{{\mathbb{Y}}}$$
$$\def\sZ{{\mathbb{Z}}}$$
$$\def\E{{\mathbb{E}}}$$
$$\def\jac{{\mathbf{\mathrm{J}}}}$$
$$\def\argmax{{\mathop{\mathrm{arg}\,\mathrm{max}}}}$$
$$\def\argmin{{\mathop{\mathrm{arg}\,\mathrm{min}}}}$$
$$\def\Tr{{\mathop{\mathrm{Tr}}}}$$
$$\def\diag{{\mathop{\mathrm{diag}}}}$$
$$\def\vec{{\mathop{\mathrm{vec}}}}$$

# Hessian row sum in PyTorch

You can grab the code as a Python script here.

A friend recently asked me how to compute the row-wise sum of the Hessian matrix in PyTorch. This can be done with a single Hessian-vector product and is thus a neat use case for that functionality.

I will first demonstrate Hessian-vector products in PyTorch on a simple example before moving on to a more popular example with neural networks.

Let's get the imports out of the way.

```python
from typing import List

from torch import Tensor, ones_like, allclose, cat, rand, manual_seed
from torch.nn import Linear, MSELoss, ReLU, Sequential

from backpack.hessianfree.hvp import hessian_vector_product
```


## Simple example

As a sanity check, let's quickly walk through a low-dimensional example.

Consider the scalar function

\begin{equation*} f(x_1, x_2) = x_{1}^{2} + 2 x_{1}^{2} x_{2} - 3 x_{2}^{3}\,, \qquad x_{1,2}\in\sR\,. \end{equation*}

Its Hessian $$\mH(x_{1}, x_{2})$$ is

\begin{equation*} \mH(x_1, x_2) = \begin{pmatrix} \frac{\partial^2 f}{\partial x_1^2} & \frac{\partial^2 f}{\partial x_1 \partial x_2} \\ \frac{\partial^2 f}{\partial x_2 \partial x_1} & \frac{\partial^2 f}{\partial x_2^2} \end{pmatrix} = \begin{pmatrix} 2 + 4x_{2} & 4x_1 \\ 4x_1 & -18x_2 \end{pmatrix} \end{equation*}

and the row-wise sum results from multiplication with a vector of ones,

\begin{equation*} \begin{pmatrix} 2 + 4x_{2} + 4x_1 \\ 4x_1 - 18x_2 \end{pmatrix} = \mH(x_1, x_2) \begin{pmatrix} 1 \\ 1 \end{pmatrix}\,. \end{equation*}

If we choose $$(x_{1} = 1, x_{2} = 2)$$, then the result should be $$\begin{pmatrix} 14 \\ -32 \end{pmatrix}$$. Let's reproduce that in code:

```python
x1 = Tensor([1.0]).requires_grad_()
x2 = Tensor([2.0]).requires_grad_()

expected = Tensor([14.0, -32.0])

f = x1 ** 2 + 2 * x1 ** 2 * x2 - 3 * x2 ** 3
```


The Hessian-vector product expects the vector in list format, with one entry shaped like each of f's variables:

```python
parameters = [x1, x2]
v_list = [ones_like(param) for param in parameters]
Hv_list = hessian_vector_product(f, parameters, v_list)

print(Hv_list)
```

(tensor([14.]), tensor([-32.]))


The output is also in list format, which we flatten and concatenate before comparing:

```python
Hv = cat([v.flatten() for v in Hv_list])

if allclose(Hv, expected):
    print("Results match.")
else:
    print("Results don't match.")
```

Results match.
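
As an independent sanity check (not part of the original script), we can also build the full Hessian with `torch.autograd.functional.hessian` and sum its rows explicitly. This brute-force approach only pays off in low dimensions, but it confirms the numbers:

```python
import torch
from torch.autograd.functional import hessian


def f(x1, x2):
    # Same function as above
    return x1 ** 2 + 2 * x1 ** 2 * x2 - 3 * x2 ** 3


# For tuple inputs, hessian returns a nested tuple H[i][j] = d^2 f / (dx_i dx_j)
H = hessian(f, (torch.tensor(1.0), torch.tensor(2.0)))
H_mat = torch.stack([torch.stack(row) for row in H])

row_sum = H_mat @ torch.ones(2)
print(row_sum)  # row sums 14 and -32, as before
```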


## Neural network example

The approach above works for arbitrary computation graphs. A more common use case, however, is the Hessian of a neural network's loss with respect to the network's parameters.

To avoid repeating myself, I'll condense everything into a utility function that you can copy:

```python
def hessian_row_sum(f: Tensor, parameters: List[Tensor]) -> Tensor:
    """Compute the Hessian row-wise sum and return it as a 1d tensor.

    Args:
        f: Function of the Hessian.
        parameters: Variables of f. Need to have requires_grad=True.

    Returns:
        Hessian row sum as vector.
    """
    Hv_list = hessian_vector_product(
        f, parameters, [ones_like(param) for param in parameters]
    )

    return cat([v.flatten() for v in Hv_list])
```
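
For reference, the same utility can be sketched with plain `torch.autograd` and no extra dependency, using double backpropagation (contract the gradient with a vector of ones, then differentiate again). This variant is my own addition, not part of the original post:

```python
from typing import List

import torch
from torch import Tensor


def hessian_row_sum_autograd(f: Tensor, parameters: List[Tensor]) -> Tensor:
    """Hessian row sum via double backpropagation (no BackPACK required)."""
    # First backward pass, keeping the graph for a second differentiation
    grads = torch.autograd.grad(f, parameters, create_graph=True)
    # Contract the gradient with a vector of ones, then differentiate again
    ones = [torch.ones_like(p) for p in parameters]
    Hv = torch.autograd.grad(grads, parameters, grad_outputs=ones)
    return torch.cat([h.flatten() for h in Hv])


# Quick check on f(x) = sum(x_i^3): the Hessian is diag(6 x), so the row sum is 6 x
x = torch.arange(3.0, requires_grad=True)
f = (x ** 3).sum()
print(hessian_row_sum_autograd(f, [x]))  # 6 * [0, 1, 2] = [0, 6, 12]
```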


Here's a simple neural net with random data:

```python
manual_seed(0)

batch_size = 3
D_in = 5
D_out = 4

X = rand(batch_size, D_in)
y = rand(batch_size, D_out)

model = Sequential(Linear(D_in, D_out, bias=False), ReLU())
loss_func = MSELoss()

output = model(X)
loss = loss_func(output, y)
```


Let's call our utility:

```python
result = hessian_row_sum(loss, list(model.parameters()))

print(result.shape)
```

torch.Size([20])
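
The shape matches the number of parameters: the bias-free `Linear(5, 4)` layer has a 4×5 weight, i.e. 20 entries. To double-check, one can brute-force the full Hessian of the same loss with `torch.autograd.functional.hessian`. This sketch re-expresses the single-layer, bias-free model functionally (my own addition, assuming the setup from above):

```python
import torch
from torch.autograd.functional import hessian

torch.manual_seed(0)

X = torch.rand(3, 5)
y = torch.rand(3, 4)
W = torch.rand(4, 5)  # stands in for the Linear layer's weight


def loss_fn(weight):
    # Functional version of Sequential(Linear(5, 4, bias=False), ReLU()) + MSELoss
    return torch.nn.functional.mse_loss(torch.relu(X @ weight.T), y)


H = hessian(loss_fn, W)  # shape (4, 5, 4, 5)
row_sum = H.reshape(20, 20) @ torch.ones(20)
print(row_sum.shape)  # torch.Size([20])
```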


Enjoy!

Created: 2021-09-13 Mon 16:10
