# Expanding `einsum` expressions

You can grab the code as a `python` script here.

`einsum` is great! You can conveniently express common linear algebra operations like matrix-vector products, as well as more complicated tensor contractions. It helps me write more readable code, especially in the context of machine learning.
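
For a flavor of what fits into a single equation, here is a quick sketch (the shapes are my own illustrative choices, not from the original example) of a few standard operations written as `einsum` calls:

```python
from torch import allclose, einsum, eye, rand

A, B = rand(3, 4), rand(4, 5)
v = rand(4)

matmul = einsum("ij,jk->ik", A, B)  # matrix-matrix product
matvec = einsum("ij,j->i", A, v)    # matrix-vector product
outer = einsum("i,j->ij", v, v)     # outer product
tr = einsum("ii->", eye(3))         # trace of the identity matrix
```

Indices on the left of `->` label the input axes, indices on the right are the surviving output axes, and indices that do not appear in the output are summed over.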

An `einsum` expression consists of two ingredients: an equation that specifies the contraction, and the input tensors.

As an example, let `A`, `x` be a matrix and a vector of appropriate size. Then, the matrix-vector product `y = A @ x` is

```python
y = einsum("ik,k->i", A, x)
```

What if `A` was formed by two other matrices `B`, `C` via `A = B @ C`? We could write

```python
A = einsum("ij,jk->ik", B, C)
```

and expand the `einsum` for `y` into

```python
y = einsum("ij,jk,k->i", B, C, x)
```

Doing that takes mental effort. And there are additional challenges not covered by this example, like inserting an expression whose indices are already used.

To automate such `einsum` expansions, I'll code an `einsum_expand` helper function in this post. For the above example, it should work as follows:

```python
A = ["ij,jk->ik", B, C]
y = ["ik,k->i", A, x]
einsum_expand(y)  # -> ["ij,jk,k->i", B, C, x] (or equivalent)
```

Let's get started! (Imports first)

```python
import re
from torch import rand, einsum, allclose, manual_seed, Tensor
from opt_einsum import contract, contract_expression
from typing import List
from timeit import repeat

manual_seed(0)
```

## Motivation

First, some thoughts on how such a function could be useful.

Say we want to evaluate a nested `einsum` expression without our `einsum_expand` helper. We could do one of the following:

- **Do the expansion by hand:** For instance, we could draw the expanded tensor network, introduce indices, and read off the contraction. But this does not scale to large expressions.
- **Naively evaluate all unnested expressions with `einsum`:** Iteratively, this will reduce the nesting until we're left with a single unnested expression. But this constrains the execution order, and that order is important for performance! Contraction optimization packages like `opt_einsum` can find the best contraction schedule and greatly improve performance. Check out this `opt_einsum` demo.

`einsum_expand` scales to arbitrarily complicated nested `einsum` expressions and, in contrast to naive evaluation, its output can be optimized by `opt_einsum` and computed faster in many scenarios.
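
As a quick illustration of what such an optimizer does, here is a sketch using `opt_einsum.contract_path` (the equation extends the matrix-vector example from above; the shapes are my own toy choices):

```python
from numpy.random import rand
from opt_einsum import contract_path

# expanded equation for y = B @ D @ E @ x, with made-up shapes
B, D, E, x = rand(10, 5), rand(5, 20), rand(20, 30), rand(30)
path, info = contract_path("ij,jk,kl,l->i", B, D, E, x)

# flop estimate of the optimized schedule vs. the unoptimized contraction
print(info.opt_cost, "<=", info.naive_cost)
```

`opt_einsum.contract` uses such a path internally, which is why handing it one big expanded equation can beat fixing the contraction order by hand.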

## Introduction

### Toy example

I will use a slightly more involved example: let `A, B, C, D, E` be matrices of matching dimensions with `A = B @ C` where `C = D @ E`:

```python
B = rand(10, 5)
D = rand(5, 20)
E = rand(20, 30)
```

Let's first compute `A` naively with the following two `einsum` expressions:

```python
C = einsum("ij,jk->ik", D, E)
A_naive = einsum("ij,jk->ik", B, C)
```

Combining the expressions into a single one, `A = B @ D @ E`,

```python
A_expanded = einsum("ij,jk,kl->il", B, D, E)
```

yields the same result:

```python
print(allclose(A_naive, A_expanded))
```

```
True
```

### Naive evaluation

I claimed that expanding `einsum` expressions can lead to speed-ups when combined with contraction optimizers. Here is the naive evaluation scheme that we will use for comparison later.

```python
def is_tensor(x):
    return isinstance(x, Tensor)


def is_nested(x):
    return isinstance(x, (tuple, list))


def is_equation(x):
    return isinstance(x, str)


def einsum_nested_naive(expression, optimize=False) -> Tensor:
    """Naively evaluate a nested einsum expression.

    Evaluate unnested expressions until no sub-expression remains.

    Args:
        expression: List describing a (potentially nested) einsum expression.
            The head is an equation, and the tail consists of tensors or more
            einsum expressions.
        optimize: If ``True``, use ``opt_einsum.contract`` to evaluate unnested
            expressions. Otherwise, use ``torch.einsum``. Default: ``False``.

    Returns:
        Result tensor.
    """
    equation, operands = expression[0], expression[1:]
    if not is_equation(equation):
        raise ValueError(f"Invalid equation as first entry: {equation}.")

    operands_flat = []
    for op in operands:
        if is_tensor(op):
            operands_flat.append(op)
        elif is_nested(op):
            operands_flat.append(einsum_nested_naive(op, optimize=optimize))
        else:
            raise ValueError(f"Expect Tensor or einsum expression, got {op}.")

    contraction_func = contract if optimize else einsum
    return contraction_func(equation, *operands_flat)
```

Let's verify it on our toy example:

```python
A = [
    "ij,jk->ik",
    B,
    ["ij,jk->ik", D, E],
]

A_nested_naive = einsum_nested_naive(A)
print(allclose(A_nested_naive, A_naive))
```

```
True
```

## Helper functions

### Working with `einsum` equations

To insert a sub-expression, we have to replace its group in the outer equation by the sub-expression's input groups. For that, we need to be able to split equations into per-operand groups and combine them back into one string. This is done by the following helpers:

```python
def equation_to_groups(equation: str) -> List[str]:
    return re.split(",|->", equation)


def groups_to_equation(groups: List[str]) -> str:
    return "->".join([",".join(groups[:-1]), groups[-1]])
```

They are best understood by examples:

```python
equation = "ij,jk,kl->il"
groups = equation_to_groups(equation)
print(f"Groups: {groups}")
print(f"Equation from groups: '{groups_to_equation(groups)}'")
```

```
Groups: ['ij', 'jk', 'kl', 'il']
Equation from groups: 'ij,jk,kl->il'
```

### Renaming indices

The replacement operation has to respect additional rules: as we're introducing the sub-expression's input indices, we need to make sure their names are not already occupied in the outer equation. We also need to rename the output indices of the equation we want to insert so that they match the group they replace.

Let's walk through our example to understand some issues we might encounter:

```python
A = ["ij,jk->ik", B, ["ij,jk->ik", C, D]]
```

We want to replace `jk` in `A[0]` by the sub-expression `A[-1]`.

The latter currently has output indices `ik`. So we need to rename `i` into `j`. But `j` is already occupied by a summation index, so we need to rename that first. This will lead to a rename of

```python
"ij,jk->ik"
```

into

```python
"ja,ak->jk"
```

We can now safely expand the renamed

```python
A = ["ij,jk->ik", B, ["ja,ak->jk", C, D]]
```

into

```python
["ij,ja,ak->ik", B, C, D]
```

In the last step we could have encountered another problem: if `a` was already taken in `A[0]`, we would have to do another rename.

The renaming is done by the following helper `rename_out_indices`:

```python
def rename_out_indices(equation, new_name, blocked=None):
    """Rename output indices into new_name, respecting blocked indices."""
    out_indices = equation_to_groups(equation)[-1]

    blocked = set() if blocked is None else blocked.copy()
    blocked = blocked.union(set(new_name)).union(set(out_indices))
    equation = _rename_indices(equation, blocked=blocked)

    out_indices = equation_to_groups(equation)[-1]
    for out_idx, new_idx in zip(out_indices, new_name):
        equation = equation.replace(out_idx, new_idx)

    return equation


def _rename_indices(equation, blocked=None):
    """Rename indices of an equation if they are blocked."""
    groups = equation_to_groups(equation)
    indices = set("".join(groups))

    blocked = set() if blocked is None else blocked.copy()
    blocked = blocked.union(indices)

    rename_indices = indices.intersection(blocked)
    for idx, new_name in zip(rename_indices, _get_free_characters(blocked=blocked)):
        equation = equation.replace(idx, new_name)

    return equation


def _get_free_characters(blocked=None):
    """Yield characters that are not in blocked, starting with 'a'."""
    blocked = set() if blocked is None else blocked
    first = ord("a")
    shift = 0
    while True:
        char = chr(first + shift)
        if char not in blocked:
            yield char
        shift += 1
```

Here is how it acts on our example:

```python
print(rename_out_indices("ij,jk->ik", "jk"))
```

```
ja,ak->jk
```

We can block certain letters from being used, too:

```python
print(rename_out_indices("ij,jk->ik", "jk", blocked={"c"}))
```

```
ja,ak->jk
```

### Inserting at a given position

We now have all the functionality to correctly rename and insert an expression.

```python
def insert_at(expression, pos):
    """Insert the sub-expression at the given position, returning a new expression."""
    equation, operands = expression[0], expression[1:]
    groups = equation_to_groups(equation)
    blocked = set("".join(groups))

    inner_equation, inner_operands = operands[pos][0], list(operands[pos][1:])
    inner_equation = rename_out_indices(inner_equation, groups[pos], blocked=blocked)
    inner_groups = equation_to_groups(inner_equation)

    out_operands = operands[:pos] + inner_operands + operands[pos + 1 :]
    out_groups = groups[:pos] + inner_groups[:-1] + groups[pos + 1 :]
    out_equation = groups_to_equation(out_groups)

    return [out_equation] + out_operands
```

Here are some examples:

Inserting an unnested sub-expression:

```python
# strings instead of tensors for more readable printing
expression = ["ij,jk->ik", "B", ["ij,jk->ik", "D", "E"]]
print(insert_at(expression, 1))
```

```
['ij,ja,ak->ik', 'B', 'D', 'E']
```

Inserting a nested sub-expression:

```python
# strings instead of tensors for more readable printing
expression = ["ij,jk->ik", "B", ["ij,jk->ik", ["ab,bk->ak", "D", "E"], "F"]]
print(insert_at(expression, 1))
print(insert_at(insert_at(expression, 1), 1))
```

```
['ij,ja,ak->ik', 'B', ['ab,bk->ak', 'D', 'E'], 'F']
['ij,jd,da,ak->ik', 'B', 'D', 'E', 'F']
```
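
As a quick numeric sanity check (my addition; the shapes match the toy example), we can verify that the spliced equation `ij,ja,ak->ik` from the first output above contracts to the same result as the step-by-step evaluation:

```python
from torch import allclose, einsum, manual_seed, rand

manual_seed(0)
B, D, E = rand(10, 5), rand(5, 20), rand(20, 30)

# step by step: A = B @ (D @ E)
step_by_step = einsum("ij,jk->ik", B, einsum("ij,jk->ik", D, E))
# single contraction over the spliced equation produced by insert_at
spliced = einsum("ij,ja,ak->ik", B, D, E)

print(allclose(step_by_step, spliced))  # True
```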

## Putting it together

Here is the `einsum_expand` helper I promised in the beginning:

```python
def einsum_expand(expression):
    """Expand a nested list of ``einsum`` expressions into a single one.

    An ``einsum`` expression is defined via a tuple/list ``(equation, *operands)``
    where operands are either ``Tensor``s or ``einsum`` expressions.

    Args:
        expression: List describing a (potentially nested) einsum expression.
            The head is an equation, and the tail consists of tensors or more
            einsum expressions.

    Returns:
        An unnested expression consisting of an equation and its operands.
    """
    expression = list(expression)

    def expandable_pos(expression):
        for pos, subexpression in enumerate(expression[1:]):
            if is_nested(subexpression):
                return pos
        raise ValueError("No expandable expression found.")

    while any(is_nested(subexpression) for subexpression in expression[1:]):
        pos = expandable_pos(expression)
        expression = insert_at(expression, pos)

    return expression
```

Let's try it out on our toy example. We compare the result of naively contracting the nested expression with the result of contracting the expanded expression with a usual `einsum`:

```python
# einsum_expand handles list- and tuple-valued expressions
A = ("ij,jk->ik", B, ["ij,jk->ik", D, E])

A_naive = einsum_nested_naive(A)
A_expanded = einsum(*einsum_expand(A))
print(allclose(A_naive, A_expanded))
```

```
True
```

## Use case

`einsum_expand` can be useful in combination with contraction optimizers like `opt_einsum`. To demonstrate this, we consider a modified version of this `opt_einsum` demo, but replace the matrix `C` by the matrix product `A @ B` to make the expression nested:

```python
dim = 30
inner_dim = dim // 10

I = rand(dim, dim, dim, dim)
A = rand(dim, inner_dim)
B = rand(inner_dim, dim)
C = ["ij,jk->ik", A, B]

expression = ["pi,qj,ijkl,rk,sl->pqrs", C, C, I, C, C]
```

We will benchmark different combinations of contraction implementations (`opt_einsum.contract` versus `torch.einsum`) and evaluation schedules (`naive` versus `expand`). Here are their definitions:

```python
def naive_torch():
    """Naively evaluate with `torch.einsum`."""
    return einsum_nested_naive(expression, optimize=False)


def naive_opt_einsum():
    """Naively evaluate with `opt_einsum.contract`."""
    return einsum_nested_naive(expression, optimize=True)


# NOTE Left out for faster execution
# def expand_torch():
#     """Expand expression then evaluate with `torch.einsum`."""
#     return einsum(*einsum_expand(expression))


def expand_opt_einsum():
    """Expand expression then evaluate with `opt_einsum.contract`."""
    return contract(*einsum_expand(expression))
```

Let's make sure they produce the same result:

```python
print(
    allclose(naive_torch(), naive_opt_einsum())
    # NOTE Left out for faster execution
    # and allclose(naive_torch(), expand_torch())
    and allclose(naive_torch(), expand_opt_einsum())
)
```

```
True
```

Let's compare their performance:

```python
REPEAT, NUMBER = 10, 20

naive_torch_time = min(repeat(naive_torch, repeat=REPEAT, number=NUMBER))
# NOTE Left out for faster execution
# expand_torch_time = min(repeat(expand_torch, repeat=REPEAT, number=NUMBER))
naive_opt_einsum_time = min(repeat(naive_opt_einsum, repeat=REPEAT, number=NUMBER))
expand_opt_einsum_time = min(repeat(expand_opt_einsum, repeat=REPEAT, number=NUMBER))

best_overall = min(
    naive_torch_time,
    # NOTE Left out for faster execution
    # expand_torch_time,
    naive_opt_einsum_time,
    expand_opt_einsum_time,
)

print(
    f"Naive + torch.einsum : {naive_torch_time:.5f} s "
    + f"(x {naive_torch_time / best_overall:.2f})"
)
# NOTE Left out for faster execution
# print(
#     f"Expand + torch.einsum : {expand_torch_time:.5f} s "
#     + f"(x {expand_torch_time / best_overall:.2f})"
# )
print(
    f"Naive + opt_einsum.contract : {naive_opt_einsum_time:.5f} s "
    + f"(x {naive_opt_einsum_time / best_overall:.2f})"
)
print(
    f"Expand + opt_einsum.contract : {expand_opt_einsum_time:.5f} s "
    + f"(x {expand_opt_einsum_time / best_overall:.2f})"
)
```

```
Naive + torch.einsum : 0.20648 s (x 4.97)
Naive + opt_einsum.contract : 0.07125 s (x 1.71)
Expand + opt_einsum.contract : 0.04157 s (x 1.00)
```

We can see that combining `einsum_expand` with `opt_einsum` performs best.

That's all for now. Enjoy!