By Pavel Panchekha

Share under CC-BY-SA.

Detecting Symmetric Expressions

As I work on Herbie, my numerical analysis assistant, I'm always thinking about how I could teach it classic numerical analysis tricks. Recently, my student Oliver Flatt and I have been thinking about tricks for symmetric expressions.

Symmetry for Accuracy

Consider the following computation, which is important in statistics:

\[ \log(e^a + e^b) \]

The standard numerical-analysis wisdom is that this is best evaluated as

max(a, b) + log1p(exp(min(a, b) - max(a, b)))
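A minimal Python sketch of why the rewrite matters; `naive_lse` and `stable_lse` are illustrative names, not Herbie output. For inputs around 1000, `math.exp` overflows in the direct formula, while the rewritten form only ever exponentiates a non-positive number.

```python
import math

def naive_lse(a: float, b: float) -> float:
    # Direct translation of log(e^a + e^b); math.exp raises OverflowError
    # once its argument exceeds roughly 709.
    return math.log(math.exp(a) + math.exp(b))

def stable_lse(a: float, b: float) -> float:
    # max(a, b) + log1p(exp(min(a, b) - max(a, b))): the argument to exp
    # is never positive, so it cannot overflow.
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))

print(stable_lse(1000.0, 1000.0))  # about 1000.693, i.e. 1000 + log(2)
try:
    naive_lse(1000.0, 1000.0)
except OverflowError:
    print("naive version overflows")
```

Note that `stable_lse` computes `max` and `min` first, so it returns the same value whichever argument comes first.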

There's a lot going on here (law of logarithms, the log1p function, and so on) but I want to focus on min and max. The key is that the original expression, in a and b, has turned into an expression in min(a, b) and max(a, b). Those are the same values as a and b, though which is which depends on the input. And this is similar to another well-known trick for computing the area of a triangle with side lengths A, B, and C:

from math import sqrt
A, B, C = sorted([A, B, C])  # ascending, so C is the longest side
area = sqrt((C + (B + A)) * (A - (C - B)) * (A + (C - B)) * (C + (B - A))) / 4
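For comparison, here is a small sketch putting the sorted formula next to textbook Heron's formula; both function names are hypothetical. On well-shaped triangles they agree, and on needle-like triangles Heron's formula can lose many digits to cancellation in `s - a` and friends. Because the sorted version sorts its inputs first, it is symmetric by construction.

```python
import math

def heron_area(a: float, b: float, c: float) -> float:
    # Textbook Heron's formula; inaccurate for long, thin triangles.
    s = (a + b + c) / 2
    return math.sqrt(s * (s - a) * (s - b) * (s - c))

def sorted_area(a: float, b: float, c: float) -> float:
    # Sort so that A <= B <= C, then use the rearranged product.
    # The parentheses are deliberate; "simplifying" them loses accuracy.
    A, B, C = sorted([a, b, c])
    return math.sqrt((C + (B + A)) * (A - (C - B)) * (A + (C - B)) * (C + (B - A))) / 4

print(sorted_area(3.0, 4.0, 5.0), sorted_area(5.0, 3.0, 4.0))  # 6.0 6.0
```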

There's a similar trick for computing the triangle's interior angles, which requires sorting the two side lengths the angle joins.
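A sketch of that angle trick, adapted from Kahan's well-known notes on needle-like triangles; the function name `angle` is hypothetical. It returns the angle between sides `a` and `b` (opposite `c`), and first sorts those two sides so that `a >= b`. Degenerate triangles are not handled carefully here.

```python
import math

def angle(a: float, b: float, c: float) -> float:
    """Interior angle (radians) between sides a and b, i.e. opposite side c."""
    # Sort the two sides that meet at the angle so that a >= b.
    a, b = max(a, b), min(a, b)
    if b >= c >= 0:
        mu = c - (a - b)
    elif c > b >= 0:
        mu = b - (a - c)
    else:
        raise ValueError("not the side lengths of a triangle")
    # Parenthesization follows the accurate formulation; do not reorder.
    num = ((a - b) + c) * mu
    den = (a + (b + c)) * ((a - c) + b)
    return 2 * math.atan(math.sqrt(num / den))

print(math.degrees(angle(3.0, 4.0, 5.0)))  # the right angle of a 3-4-5 triangle
```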

It'd be way cool if Herbie could find these tricks automatically!

A Proposed Approach

Traditionally, Herbie deals with cases like this by adding branches, using an algorithm we call regime inference. But that algorithm basically requires treating each branch as its own special case, and while with two variables there are only two orderings, with three variables there are already six (and n! in general), which is more than Herbie can handle well. But symmetric expressions don't need to treat each branch as a special case; they are all more or less the same case, with the variable names swapped around.

Herbie works based on sampled input points. If we knew that the expression was symmetric in variables A, B, and C, we could just take all the sampled inputs and sort A, B, and C. Then the rest of Herbie would operate on a single case (where the variables come sorted), and at the end we'd take Herbie's results and tack on a sort in front to handle the other cases. Basically nothing in Herbie's internal structure would need to be disturbed!
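A minimal sketch of that preprocessing, assuming each sampled point is a dict from variable names to floats and that the groups of interchangeable variables are already known. `sort_point` and `with_sort` are hypothetical helpers for illustration, not Herbie functions.

```python
from typing import Callable, Dict, Sequence

Point = Dict[str, float]

def sort_point(point: Point, groups: Sequence[Sequence[str]]) -> Point:
    # Within each group of interchangeable variables, reassign the group's
    # values in ascending order: the first name gets the smallest value.
    out = dict(point)
    for group in groups:
        values = sorted(point[v] for v in group)
        for var, val in zip(group, values):
            out[var] = val
    return out

def with_sort(f: Callable[[Point], float],
              groups: Sequence[Sequence[str]]) -> Callable[[Point], float]:
    # Wrap a program that was only improved on sorted inputs so that it
    # can be applied to any input: sort first, then run it.
    return lambda point: f(sort_point(point, groups))

print(sort_point({"a": 3.0, "b": 1.0, "c": 2.0}, [["a", "b", "c"]]))
```

Variables outside every group, like `c` in a * b + c, are simply left alone.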

All we need is a way to detect when an expression is symmetric in a set of variables.

Detecting symmetries

This turns out to be harder than expected. Sometimes, all variables in an expression can be sorted, like in a + b + c. Other times, only a subset can: in a * b + c, it's fine to sort a and b but c must stay put. Other times, there are two disjoint sets of sortable variables, like in pow(a * d, b + c). But there can also be strange cases, like a^3 b^2 c + b^3 c^2 a + c^3 a^2 b, where you can rotate a, b, and c but you can't sort them. Or what about a * d + b * c, where you can swap a and d, or b and c, and you can also swap the pairs, but you still can't arbitrarily sort all four variables?

To abstract a bit: for any expression E[x_1, ..., x_n], some permutations of its variables leave the expression's value unchanged. Those permutations can be composed and inverted, so they form a group, but in general it is some arbitrary subgroup \(G\) of the full symmetric group \(S_n\) on the free variables of E. We want to find the largest subgroup of \(G\) that is a product of symmetric groups on disjoint sets of variables:

\[ \left( \prod_i S_{n_i} \right) \leq G \leq S_n \]
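To make \(G\) concrete for the examples above, here is a brute-force sketch that tries all n! permutations and keeps those that leave the expression's value unchanged at a few random integer points. This is a swapped-in technique for illustration only, not the e-graph approach Herbie uses: it is exponential in n, and random testing is strong evidence of a symmetry (for polynomials, with exact integer arithmetic) rather than a proof. `symmetry_group` is a hypothetical name.

```python
import itertools
import random
from typing import Callable, List, Tuple

def symmetry_group(f: Callable[..., int], nvars: int, trials: int = 5,
                   seed: int = 0) -> List[Tuple[int, ...]]:
    # Keep each permutation p with f(x[p[0]], ..., x[p[n-1]]) == f(x[0], ..., x[n-1])
    # at every sampled point. Integers keep the arithmetic exact, so there is
    # no floating-point rounding noise in the comparison.
    rng = random.Random(seed)
    points = [[rng.randint(1, 10**6) for _ in range(nvars)] for _ in range(trials)]
    group = []
    for perm in itertools.permutations(range(nvars)):
        if all(f(*(x[i] for i in perm)) == f(*x) for x in points):
            group.append(perm)
    return group

cyclic = lambda a, b, c: a**3 * b**2 * c + b**3 * c**2 * a + c**3 * a**2 * b
pairs = lambda a, b, c, d: a * d + b * c
print(len(symmetry_group(cyclic, 3)))  # only the three rotations survive
print(len(symmetry_group(pairs, 4)))   # swap within pairs, and swap the pairs
```

Neither of these groups is itself a product of full symmetric groups, which is exactly why we want the largest such product inside \(G\) rather than \(G\) itself.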

First, it's important to note that this largest product is unique.

A symmetric group is generated by its pairwise swaps, and if both \((a, b)\) and \((b, c)\) can be swapped, so can \((a, c)\), via the sequence \((a, b) (b, c) (a, b)\) familiar to computer scientists as the xor swapping trick. Transitivity suggests a graph, so make a graph where each variable is a node, and an edge between two variables means they can be swapped without affecting the expression's value. Transitivity guarantees that this graph is the union of disjoint cliques, which corresponds to that largest product.
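A tiny check of the swap identity, tracking where each variable's value ends up after applying \((a, b)\), then \((b, c)\), then \((a, b)\). If an expression's value is unchanged by each individual swap, it is unchanged by the whole sequence, so the \((a, c)\) swap comes for free. Since the sequence is a palindrome, the result doesn't depend on which composition convention you use. The helper name `swap` is hypothetical.

```python
def swap(env: dict, x: str, y: str) -> dict:
    # Return a copy of env with the values of variables x and y exchanged.
    out = dict(env)
    out[x], out[y] = env[y], env[x]
    return out

env = {"a": 1, "b": 2, "c": 3}
for x, y in [("a", "b"), ("b", "c"), ("a", "b")]:
    env = swap(env, x, y)
print(env)  # {'a': 3, 'b': 2, 'c': 1}: a and c exchanged, b back in place
```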

So how do we find it?

Well, the swap generators suggest a strategy. First, consider each "swapped" form of the expression, with two variables swapped. Test each swapped form for equality to the unswapped form; build a graph of equal ones. In fact, it's not strictly required to test every pair: the connected components of the graph fill in "missing" swaps that you didn't bother testing.
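A sketch of that strategy using union-find for the connected components. Here the equality test is a hypothetical stand-in (comparing values at random integer points) for a real equality proof, and `sortable_groups` is an illustrative name. Pairs that are already in the same component are skipped, since transitivity already guarantees they can be swapped.

```python
import itertools
import random
from typing import Callable, Dict, List

def sortable_groups(f: Callable[..., int], names: List[str], trials: int = 5,
                    seed: int = 0) -> List[List[str]]:
    rng = random.Random(seed)
    points = [[rng.randint(1, 10**6) for _ in names] for _ in range(trials)]
    parent = list(range(len(names)))  # union-find forest over variable indices

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    def swap_preserves(i: int, j: int) -> bool:
        # Stand-in for an equality proof: compare f before and after
        # swapping x_i and x_j at each random point.
        for x in points:
            y = list(x)
            y[i], y[j] = y[j], y[i]
            if f(*y) != f(*x):
                return False
        return True

    for i, j in itertools.combinations(range(len(names)), 2):
        # Only test pairs not already connected through earlier swaps.
        if find(i) != find(j) and swap_preserves(i, j):
            parent[find(i)] = find(j)

    groups: Dict[int, List[str]] = {}
    for i, name in enumerate(names):
        groups.setdefault(find(i), []).append(name)
    return list(groups.values())

print(sortable_groups(lambda a, b, c, d: a * d + b * c, ["a", "b", "c", "d"]))
```

For a * d + b * c this yields the groups {a, d} and {b, c}, and for the rotation-only example every variable ends up in its own group.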

And in Herbie, we already incorporate a powerful algorithm, e-graphs, for proving terms equal to one another. In fact, e-graphs can prove several terms equal to one another in parallel. To use e-graphs for this, we create an e-graph with the original expression in it, plus every swapped version. The e-graph is then iterated, which means it uses a bunch of rewrite rules to prove various expressions and subexpressions equal. After it's done, we just need to see which swapped versions ended up equal to the original version, build the graph, compute connected components, and we're done. If we're feeling cheeky, we can even compute connected components using the e-graph structure, though in practice we're talking about 0-20 variables so that isn't really necessary.
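Herbie uses a real e-graph for the equality step. The sketch below substitutes a much weaker technique, named plainly: canonicalizing terms modulo associativity and commutativity of + and * and comparing normal forms. It cannot find equalities that need other rewrites (distributivity, for example) the way an e-graph can, but it shows the overall shape: build every swapped version, decide which ones equal the original, and pass the equal pairs on to the connected-components step. The expression encoding and all function names are hypothetical.

```python
import itertools
from typing import List, Tuple, Union

Expr = Union[str, int, tuple]  # variable name, constant, or (op, arg1, arg2, ...)

def swap(e: Expr, x: str, y: str) -> Expr:
    # Rename x to y and y to x throughout e (operators are never renamed).
    if e == x:
        return y
    if e == y:
        return x
    if isinstance(e, tuple):
        return (e[0],) + tuple(swap(a, x, y) for a in e[1:])
    return e

def canon(e: Expr) -> Expr:
    # Normal form modulo associativity and commutativity of + and *:
    # flatten nested uses of the same operator, then sort the operands.
    if not isinstance(e, tuple):
        return e
    op, args = e[0], [canon(a) for a in e[1:]]
    if op in ("+", "*"):
        flat = []
        for a in args:
            if isinstance(a, tuple) and a[0] == op:
                flat.extend(a[1:])
            else:
                flat.append(a)
        return (op,) + tuple(sorted(flat, key=repr))
    return (op,) + tuple(args)

def free_vars(e: Expr) -> List[str]:
    if isinstance(e, str):
        return [e]
    if isinstance(e, tuple):
        return sorted({v for a in e[1:] for v in free_vars(a)})
    return []

def equal_swaps(e: Expr) -> List[Tuple[str, str]]:
    # Pairs of variables whose swap provably leaves e unchanged under these rules.
    target = canon(e)
    return [(x, y) for x, y in itertools.combinations(free_vars(e), 2)
            if canon(swap(e, x, y)) == target]

print(equal_swaps(("pow", ("*", "a", "d"), ("+", "b", "c"))))  # [('a', 'd'), ('b', 'c')]
```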

Initial Results

I implemented this algorithm in the new symmetry branch in Herbie and it seems to work; I ran it on the Herbie benchmark suite and it found 80 benchmarks with non-trivial symmetries, including one (Linear.V4:$cdot from linear-1.19.1.3, C) with four separate groups of sortable variables. Oliver and I will be looking into integrating this into Herbie in the near future.