Expanding einsum expressions


You can grab the code as a Python script here.

einsum is great! You can conveniently express common linear algebra operations like matrix-vector products, and more complicated tensor contractions. It helps me to write more readable code, especially in the context of machine learning.

An einsum expression consists of two ingredients: an equation that specifies the contraction, and the input tensors.

As an example, let A, x be a matrix and a vector of appropriate size. Then, the matrix-vector product y = A @ x is

y = einsum("ik,k->i", A, x)

What if A was formed by two other matrices B, C via A = B @ C? We could write

A = einsum("ij,jk->ik", B, C)

and expand the einsum for y into

y = einsum("ij,jk,k->i", B, C, x)

Doing that takes mental effort. And there are additional challenges not covered by this example, like inserting an expression whose indices are already used.
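To see why clashing indices are a real problem, here is a small sketch. It uses NumPy's einsum, which accepts the same equation syntax as torch.einsum, so the snippet runs without PyTorch. The shapes, the seed, and the choice of writing A as "ik,kj->ij" are arbitrary assumptions made for illustration. Written that way, A's summation index k collides with the summation index of y = A @ x, and pasting A's input groups in without renaming silently computes something else.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
B = rng.random((n, n))
C = rng.random((n, n))
x = rng.random(n)

# Reference: y = A @ x with A = B @ C
reference = B @ C @ x

# Assume A is given as ["ik,kj->ij", B, C] and y as ["ik,k->i", A, x].
# Pasting A's input groups "ik,kj" over the group "ik" without renaming
# makes k play two roles: A's summation index and y's summation index.
naive = np.einsum("ik,kj,k->i", B, C, x)

# Renaming A's summation index k -> a and its output index j -> k first
# gives "ia,ak->ik", whose input groups can be pasted in safely.
expanded = np.einsum("ia,ak,k->i", B, C, x)

print(np.allclose(expanded, reference))  # True
print(np.allclose(naive, reference))     # False
```

With square matrices the naive splice runs without complaint and returns a different vector; with non-square shapes it may instead fail with a shape error.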

To automate such einsum expansions, I'll code an einsum_expand helper function in this post. For the above example, it should work as follows:

A = ["ij,jk->ik", B, C]
y = ["ik,k->i", A, x]

einsum_expand(y) # -> ["ij,jk,k->i", B, C, x] (or equivalent)
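The "(or equivalent)" hints that index letters are arbitrary: "ij,jk,k->i" and "ia,ak,k->i" describe the same contraction. A minimal way to compare equations up to such renaming is to relabel letters in order of first appearance. The helper canonicalize below is hypothetical and not part of the post's code; it also does not account for reordered operands.

```python
def canonicalize(equation: str) -> str:
    """Relabel indices as a, b, c, ... in order of first appearance (hypothetical helper)."""
    mapping = {}
    for char in equation:
        if char.isalpha() and char not in mapping:
            mapping[char] = chr(ord("a") + len(mapping))
    # Separators such as "," and "->" are kept unchanged
    return "".join(mapping.get(char, char) for char in equation)


print(canonicalize("ij,jk,k->i"))  # ab,bc,c->a
print(canonicalize("ia,ak,k->i"))  # ab,bc,c->a
```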

Let's get started! (Imports first)

import re
from torch import rand, einsum, allclose, manual_seed, Tensor
from opt_einsum import contract, contract_expression
from typing import List
from timeit import repeat



First, some thoughts on how such a function could be useful.

Say we want to evaluate a nested einsum expression without our einsum_expand helper. We could do one of the following:

  1. Do the expansion by hand: For instance, we could draw the expanded tensor network, introduce indices, and read off the contraction. But this does not scale to large expressions.
  2. Naively evaluate all unnested expressions with einsum: Iteratively, this reduces the nesting until we're left with a single unnested expression. This constrains the execution order, since every sub-expression has to be contracted before the expression that contains it. But this order is important for performance! Contraction optimization packages like opt_einsum can find an efficient contraction schedule and greatly improve performance. Check out this opt_einsum demo.
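To make the cost of a fixed order concrete, take the introductory example y = B @ C @ x and assume that B, C are n x n matrices and x has n entries (n = 1000 is an arbitrary choice). Counting scalar multiplications only, a rough proxy for runtime, the naive scheme must form the matrix product B @ C first, whereas an optimizer working on the expanded equation "ij,jk,k->i" is free to contract C with x first. The helper matmul_mults is hypothetical.

```python
def matmul_mults(m: int, k: int, n: int) -> int:
    """Scalar multiplications of an (m x k) @ (k x n) matrix product."""
    return m * k * n


n = 1_000  # B and C are n x n, x has n entries (assumed sizes)

# Naive nested evaluation must form A = B @ C before computing y = A @ x
naive = matmul_mults(n, n, n) + matmul_mults(n, n, 1)

# After expansion, an optimizer may instead choose y = B @ (C @ x)
reordered = matmul_mults(n, n, 1) + matmul_mults(n, n, 1)

print(naive, reordered, naive / reordered)  # 1001000000 2000000 500.5
```

opt_einsum estimates such costs over the whole tensor network when choosing a path; this count only illustrates the gap for one simple case.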

einsum_expand scales to arbitrarily complicated nested einsum expressions. In contrast to naive evaluation, its output can be optimized by opt_einsum as a whole, and in many scenarios it is then computed faster.


Toy example

I will use a slightly more involved example: Let A, B, C, D, E be matrices of matching dimensions and A = B @ C where C = D @ E:

B = rand(10, 5)
D = rand(5, 20)
E = rand(20, 30)

Let's first compute A naively with the following two einsum expressions:

C = einsum("ij,jk->ik", D, E)
A_naive = einsum("ij,jk->ik", B, C)

Combining the expressions into a single one, A = B @ D @ E,

A_expanded = einsum("ij,jk,kl->il", B, D, E)

yields the same result:

print(allclose(A_naive, A_expanded))

Naive evaluation

I claimed that expanding einsum expressions can lead to speed-ups when combined with contraction optimizers. Here is the naive evaluation scheme that we will use for comparison later.

def is_tensor(x):
    return isinstance(x, Tensor)

def is_nested(x):
    return isinstance(x, (tuple, list))

def is_equation(x):
    return isinstance(x, str)

def einsum_nested_naive(expression, optimize=False) -> Tensor:
    """Naively evaluate a nested einsum expression.

    Evaluate unnested expressions until no sub-expressions remain.

    Args:
        expression: List describing a (potentially nested) einsum expression.
            The head is an equation, and the tail consists of tensors or more
            einsum expressions.
        optimize: If ``True``, use ``opt_einsum.contract`` to evaluate unnested
            expressions. Otherwise, use ``torch.einsum``. Default: ``False``.

    Returns:
        Result tensor.
    """
    equation, operands = expression[0], expression[1:]

    if not is_equation(equation):
        raise ValueError(f"Invalid equation as first entry: {equation}.")

    operands_flat = []

    for op in operands:
        if is_tensor(op):
            operands_flat.append(op)
        elif is_nested(op):
            # recursively contract the sub-expression into a single tensor
            operands_flat.append(einsum_nested_naive(op, optimize=optimize))
        else:
            raise ValueError(f"Expect Tensor or einsum expression, got {op}.")

    contraction_func = contract if optimize else einsum

    return contraction_func(equation, *operands_flat)
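To make the fixed execution order of einsum_nested_naive visible, here is a sketch of a hypothetical helper, naive_evaluation_order, that follows the same recursion but records each equation instead of contracting it. Strings stand in for tensors. Every sub-expression is contracted as a whole before the expression containing it, so no optimizer ever sees the full network.

```python
def naive_evaluation_order(expression, trace=None):
    """Record equations in the order einsum_nested_naive would contract them.

    Mirrors the recursion of einsum_nested_naive (hypothetical helper), but
    appends the equation to ``trace`` instead of calling a contraction routine.
    """
    trace = [] if trace is None else trace
    equation, operands = expression[0], expression[1:]
    for op in operands:
        if isinstance(op, (tuple, list)):
            naive_evaluation_order(op, trace)
    trace.append(equation)
    return trace


# strings stand in for tensors
nested = ["pi,qj->pq", ["ij,jk->ik", "B", "C"], ["ab,bc->ac", "D", "E"]]
print(naive_evaluation_order(nested))
# ['ij,jk->ik', 'ab,bc->ac', 'pi,qj->pq']
```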

Let's verify it on our toy example:

A = [
    "ij,jk->ik",
    B,
    ["ij,jk->ik", D, E],
]
A_nested_naive = einsum_nested_naive(A)

print(allclose(A_nested_naive, A_naive))

Helper functions

Working with einsum equations

To insert a sub-expression, we have to replace the index group of the operand it stands for by the sub-expression's groups of input indices. For that, we need to be able to split equations into per-operand groups and to combine them back into one string. This is done by the following helpers:

def equation_to_groups(equation: str) -> List[str]:
    return re.split(",|->", equation)

def groups_to_equation(groups: List[str]) -> str:
    return "->".join([",".join(groups[:-1]), groups[-1]])

They are best understood by examples:

equation = "ij,jk,kl->il"
groups = equation_to_groups(equation)

print(f"Groups: {groups}")
print(f"Equation from groups: '{groups_to_equation(groups)}'")
Groups: ['ij', 'jk', 'kl', 'il']
Equation from groups: 'ij,jk,kl->il'
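As a sketch of how these helpers combine, here is the splice from the introduction done by hand on the group level: the group of A in the outer equation is replaced by the input groups of A's equation. The helpers are repeated so the snippet runs on its own. This works without renaming only because A's output indices already equal the group they replace and A's summation index j does not appear in the outer equation; the next section deals with the general case.

```python
import re
from typing import List


def equation_to_groups(equation: str) -> List[str]:
    return re.split(",|->", equation)


def groups_to_equation(groups: List[str]) -> str:
    return "->".join([",".join(groups[:-1]), groups[-1]])


outer = equation_to_groups("ik,k->i")    # y = A @ x
inner = equation_to_groups("ij,jk->ik")  # A = B @ C
pos = 0  # A is the first operand of the outer expression

# The inner output group must match the outer group it replaces
assert inner[-1] == outer[pos]

# Replace the group of A by the input groups of B and C
spliced = outer[:pos] + inner[:-1] + outer[pos + 1 :]
print(groups_to_equation(spliced))  # ij,jk,k->i
```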

Renaming indices

The aforementioned replacement has to respect additional rules: As we're introducing the sub-expression's input indices, we need to make sure their names are not already occupied. We also need to rename the output indices of the inserted equation so that they match the index group they replace.

Let's walk through our example to understand some issues we might encounter:

A = ["ij,jk->ik", B, ["ij,jk->ik", D, E]]

We want to replace jk in A[0] by the sub-expression A[-1].

The latter currently has output indices ik, so we need to rename i into j. But j is already occupied by a summation index, so we need to rename that first. This leads to the following renames (k stays as is):

  • j into a, to free up j
  • i into j, to match the target indices jk

We can now safely expand the renamed sub-expression

A = ["ij,jk->ik", B, ["ja,ak->jk", D, E]]

into

["ij,ja,ak->ik", B, D, E]

In the last step we could have encountered another problem: If a was already taken in A[0], we would have to do another rename.
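"Rename that first" matters because renames applied one after another with str.replace interact. A short sketch of what goes wrong with the opposite order, plus str.translate as an order-independent alternative (the post's helpers do not use translate):

```python
equation = "ij,jk->ik"  # sub-expression whose output should become "jk"

# Renaming i -> j first merges i with the existing j, and the
# following j -> a then renames both of them
wrong = equation.replace("i", "j").replace("j", "a")

# Freeing j first (j -> a), then renaming i -> j keeps the indices distinct
right = equation.replace("j", "a").replace("i", "j")

# str.translate applies all substitutions simultaneously, so order is irrelevant
simultaneous = equation.translate(str.maketrans({"i": "j", "j": "a"}))

print(wrong)         # aa,ak->ak
print(right)         # ja,ak->jk
print(simultaneous)  # ja,ak->jk
```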

The renaming is done by the following helper rename_out_indices:

def rename_out_indices(equation, new_name, blocked=None):
    """Rename output indices into new_name, respecting blocked indices."""
    out_indices = equation_to_groups(equation)[-1]

    blocked = set() if blocked is None else blocked.copy()
    blocked = blocked.union(set(new_name)).union(set(out_indices))

    equation = _rename_indices(equation, blocked=blocked)

    out_indices = equation_to_groups(equation)[-1]

    for out_idx, new_idx in zip(out_indices, new_name):
        equation = equation.replace(out_idx, new_idx)

    return equation

def _rename_indices(equation, blocked=None):
    """Rename all indices to characters that are neither blocked nor used."""
    groups = equation_to_groups(equation)
    indices = set("".join(groups))

    blocked = set() if blocked is None else blocked.copy()
    blocked = blocked.union(indices)

    rename_indices = indices.intersection(blocked)

    for idx, new_name in zip(rename_indices, _get_free_characters(blocked=blocked)):
        equation = equation.replace(idx, new_name)

    return equation

def _get_free_characters(blocked=None):
    """Yield characters that are not in blocked, starting with a."""
    blocked = set() if blocked is None else blocked

    first = ord("a")
    shift = 0

    while True:
        char = chr(first + shift)
        if char not in blocked:
            yield char
        shift += 1

Here is how it acts on our example:

print(rename_out_indices("ij,jk->ik", "jk"))

We can block certain letters from being used, too:

print(rename_out_indices("ij,jk->ik", "jk", blocked={"c"}))
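One caveat: _rename_indices iterates over a set, and since Python randomizes string hashing per process by default, the letter a summation index receives can differ between interpreter runs (the contraction stays correct either way). If reproducible equations matter, a minimal variant visits the indices in sorted order. The names _rename_indices_sorted and rename_out_indices_sorted are hypothetical; otherwise the logic follows the helpers above.

```python
import re
from typing import Iterator, List, Optional, Set


def equation_to_groups(equation: str) -> List[str]:
    return re.split(",|->", equation)


def _get_free_characters(blocked: Set[str]) -> Iterator[str]:
    """Yield characters that are not in blocked, starting with a."""
    code = ord("a")
    while True:
        if chr(code) not in blocked:
            yield chr(code)
        code += 1


def _rename_indices_sorted(equation: str, blocked: Set[str]) -> str:
    """Like _rename_indices, but visits indices in sorted order (hypothetical)."""
    indices = set("".join(equation_to_groups(equation)))
    blocked = blocked | indices
    free = _get_free_characters(blocked)
    # sorted() fixes the order in which indices receive new letters
    for idx, new_idx in zip(sorted(indices), free):
        equation = equation.replace(idx, new_idx)
    return equation


def rename_out_indices_sorted(
    equation: str, new_name: str, blocked: Optional[Set[str]] = None
) -> str:
    """Deterministic variant of rename_out_indices (hypothetical)."""
    out_indices = equation_to_groups(equation)[-1]
    blocked = set() if blocked is None else set(blocked)
    blocked |= set(new_name) | set(out_indices)

    equation = _rename_indices_sorted(equation, blocked)

    out_indices = equation_to_groups(equation)[-1]
    for out_idx, new_idx in zip(out_indices, new_name):
        equation = equation.replace(out_idx, new_idx)
    return equation


print(rename_out_indices_sorted("ij,jk->ik", "jk"))                 # jb,bk->jk
print(rename_out_indices_sorted("ij,jk->ik", "jk", blocked={"b"}))  # jc,ck->jk
```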

Inserting at a given position

We now have all the functionality to correctly rename and insert an expression.

def insert_at(expression, pos):
    """Return a new expression with the sub-expression at operand pos inlined.

    ``pos`` indexes the operands, i.e. ``pos=0`` refers to ``expression[1]``.
    """
    equation, operands = expression[0], list(expression[1:])
    groups = equation_to_groups(equation)
    blocked = set("".join(groups))

    inner_equation, inner_operands = operands[pos][0], list(operands[pos][1:])
    inner_equation = rename_out_indices(inner_equation, groups[pos], blocked=blocked)
    inner_groups = equation_to_groups(inner_equation)

    out_operands = operands[:pos] + inner_operands + operands[pos + 1 :]
    out_groups = groups[:pos] + inner_groups[:-1] + groups[pos + 1 :]
    out_equation = groups_to_equation(out_groups)

    return [out_equation] + out_operands

Here are some examples:

  • Inserting an unnested sub-expression:

    # strings instead of tensors for more readable printing
    expression = ["ij,jk->ik", "B", ["ij,jk->ik", "D", "E"]]
    print(insert_at(expression, 1))
    ['ij,ja,ak->ik', 'B', 'D', 'E']
  • Inserting a nested sub-expression:

    # strings instead of tensors for more readable printing
    expression = ["ij,jk->ik", "B", ["ij,jk->ik", ["ab,bk->ak", "D", "E"], "F"]]
    print(insert_at(expression, 1))
    print(insert_at(insert_at(expression, 1), 1))
    ['ij,ja,ak->ik', 'B', ['ab,bk->ak', 'D', 'E'], 'F']
    ['ij,jd,da,ak->ik', 'B', 'D', 'E', 'F']

Putting it together

Here is the einsum_expand helper I promised in the beginning:

def einsum_expand(expression):
    """Expand a nested list of ``einsum`` expressions into a single one.

    An ``einsum`` expression is defined via a tuple/list ``(equation, *operands)``
    where operands are either ``Tensor``s or ``einsum`` expressions.

    Args:
        expression: List describing a (potentially nested) einsum expression.
            The head is an equation, and the tail consists of tensors or more
            einsum expressions.

    Returns:
        An unnested expression consisting of an equation and its operands.
    """
    expression = list(expression)

    def expandable_pos(expression):
        for pos, subexpression in enumerate(expression[1:]):
            if is_nested(subexpression):
                return pos
        raise ValueError("No expandable expression found.")

    while any(is_nested(subexpression) for subexpression in expression[1:]):
        pos = expandable_pos(expression)
        expression = insert_at(expression, pos)

    return expression
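A way to predict how often the while loop runs: each insert_at call removes exactly one nested list and lifts that list's operands one level up, so the loop runs once per nested sub-expression, at any depth. The hypothetical helper count_subexpressions below counts them. For the nested example used with insert_at earlier it returns 2, matching the two insert_at calls that were needed there.

```python
def is_nested(x):
    return isinstance(x, (tuple, list))


def count_subexpressions(expression) -> int:
    """Count nested sub-expressions below the top level (hypothetical helper)."""
    return sum(
        1 + count_subexpressions(op) for op in expression[1:] if is_nested(op)
    )


# strings stand in for tensors, as in the insert_at examples
expression = ["ij,jk->ik", "B", ["ij,jk->ik", ["ab,bk->ak", "D", "E"], "F"]]
print(count_subexpressions(expression))  # 2
```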

Let's try it out on our toy example. We compare the result of naively contracting the nested expression with the result of contracting the expanded expression with a plain einsum:

# einsum_expand handles list- and tuple-valued expressions
A = ("ij,jk->ik", B, ["ij,jk->ik", D, E])

A_naive = einsum_nested_naive(A)
A_expanded = einsum(*einsum_expand(A))

print(allclose(A_naive, A_expanded))

Use case

einsum_expand can be useful in combination with contraction optimizers like opt_einsum. To demonstrate this, we consider a modified version of this opt_einsum demo, in which I replace the matrix C with the matrix product A @ B to make the expression nested:

dim = 30
inner_dim = dim // 10

I = rand(dim, dim, dim, dim)
A = rand(dim, inner_dim)
B = rand(inner_dim, dim)
C = ["ij,jk->ik", A, B]

expression = ["pi,qj,ijkl,rk,sl->pqrs", C, C, I, C, C]

We will benchmark different combinations of contraction implementations (opt_einsum.contract versus torch.einsum) and evaluation schedules (naive versus expand). Here are their definitions:

def naive_torch():
    """Naively evaluate with `torch.einsum`."""
    return einsum_nested_naive(expression, optimize=False)

def naive_opt_einsum():
    """Naively evaluate with `opt_einsum.contract`."""
    return einsum_nested_naive(expression, optimize=True)

# NOTE Left out for faster execution
# def expand_torch():
#     """Expand expression then evaluate with `torch.einsum`."""
#     return einsum(*einsum_expand(expression))

def expand_opt_einsum():
    """Expand expression then evaluate with `opt_einsum.contract`."""
    return contract(*einsum_expand(expression))

Let's make sure they produce the same result:

print(
    allclose(naive_torch(), naive_opt_einsum())
    # NOTE Left out for faster execution
    # and allclose(naive_torch(), expand_torch())
    and allclose(naive_torch(), expand_opt_einsum())
)

Let's compare their performance:

# NOTE Assumed values; the settings behind the timings below are not shown
REPEAT, NUMBER = 10, 1

naive_torch_time = min(repeat(naive_torch, repeat=REPEAT, number=NUMBER))
# NOTE Left out for faster execution
# expand_torch_time = min(repeat(expand_torch, repeat=REPEAT, number=NUMBER))
naive_opt_einsum_time = min(repeat(naive_opt_einsum, repeat=REPEAT, number=NUMBER))
expand_opt_einsum_time = min(repeat(expand_opt_einsum, repeat=REPEAT, number=NUMBER))

best_overall = min(
    naive_torch_time,
    # NOTE Left out for faster execution
    # expand_torch_time,
    naive_opt_einsum_time,
    expand_opt_einsum_time,
)

print(
    f"Naive  + torch.einsum        : {naive_torch_time:.5f} s "
    + f"(x {naive_torch_time / best_overall:.2f})"
)
# NOTE Left out for faster execution
# print(
#     f"Expand + torch.einsum        : {expand_torch_time:.5f} s "
#     + f"(x {expand_torch_time / best_overall:.2f})"
# )
print(
    f"Naive  + opt_einsum.contract : {naive_opt_einsum_time:.5f} s "
    + f"(x {naive_opt_einsum_time / best_overall:.2f})"
)
print(
    f"Expand + opt_einsum.contract : {expand_opt_einsum_time:.5f} s "
    + f"(x {expand_opt_einsum_time / best_overall:.2f})"
)
Naive  + torch.einsum        : 0.20648 s (x 4.97)
Naive  + opt_einsum.contract : 0.07125 s (x 1.71)
Expand + opt_einsum.contract : 0.04157 s (x 1.00)

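In this run, combining einsum_expand with opt_einsum performs best: it is roughly 1.7x faster than naive evaluation with opt_einsum and roughly 5x faster than naive evaluation with torch.einsum.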

That's all for now. Enjoy!

Author: Felix Dangel

Created: 2022-04-02 Sat 15:28