Synthesizing Range Reductions
As part of my work with Ian on synthesizing math libraries, I've been thinking about synthesizing range reductions. I think we've found a beautiful and efficient way to do this tricky bit of synthesis using e-graphs. This idea is also related to recent Herbie work on detecting symmetric expressions.
What is range reduction
The core of modern library function implementation is the Remez algorithm, which approximates a function f over an interval [a, b] with polynomials. (When certain conditions are met, the Remez algorithm finds the polynomial approximation with minimal worst-case error, to an arbitrary level of accuracy.) But typically when you are actually implementing a mathematical function, you want it to work for an arbitrary real-number input, not for some narrow range [a, b]. Moreover, the wider the interval [a, b], the worse the approximation error that the Remez algorithm provides.
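To see the interval-width effect concretely, here is a minimal sketch. It does not run the Remez algorithm; as a stand-in it interpolates at Chebyshev nodes, which usually lands close to the minimax polynomial. The helper name `cheb_interp_error`, the degree, and the two intervals are assumptions chosen purely for illustration.

```python
import math

def cheb_interp_error(f, a, b, degree, samples=2001):
    """Worst observed error of the Chebyshev-node interpolant of f on [a, b]."""
    n = degree + 1
    # Chebyshev nodes, mapped from [-1, 1] onto [a, b]
    nodes = [0.5 * (a + b) + 0.5 * (b - a) * math.cos(math.pi * (2 * i + 1) / (2 * n))
             for i in range(n)]
    values = [f(t) for t in nodes]

    def interp(x):
        # Plain Lagrange form; adequate for the small degrees used here
        total = 0.0
        for i in range(n):
            term = values[i]
            for j in range(n):
                if i != j:
                    term *= (x - nodes[j]) / (nodes[i] - nodes[j])
            total += term
        return total

    grid = [a + (b - a) * k / (samples - 1) for k in range(samples)]
    return max(abs(interp(x) - f(x)) for x in grid)

# Same degree, two interval widths
narrow = cheb_interp_error(math.log, 0.5, 1.0, degree=5)
wide = cheb_interp_error(math.log, 0.5, 16.0, degree=5)
```

Comparing `narrow` and `wide` shows the same polynomial degree doing much worse on the wider interval, which is the problem range reduction is meant to solve.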
Range reduction refers to the algebraic tricks used to circumvent this problem. For example, consider implementing log for arbitrary real inputs. (Keep in mind that in my work I am more interested in implementing functions that typically don't have pre-existing library implementations. But it's easier to give examples with well-known functions!) As is well-known, log turns multiplications into additions, so we can implement:
import math

def log(x):
    m, e = math.frexp(x)  # x == m * 2**e, with m in [0.5, 1)
    return e * LOG2 + log_rr(m)  # log(x) = e * log(2) + log(m)
Here, log_rr only ever receives inputs in [0.5, 1], so it can be implemented with Remez exchange. The LOG2 constant can be precomputed (in high precision if needed) and frexp can just use bit tricks.
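As a rough illustration of the "bit tricks" remark, here is a sketch of frexp done purely on the IEEE 754 double-precision bit pattern. The name `frexp_bits` is made up for this example, and the sketch only handles positive, normal, finite doubles; zeros, subnormals, infinities, NaNs, and negative inputs are deliberately rejected.

```python
import math
import struct

def frexp_bits(x):
    """frexp for positive, normal, finite doubles using only bit manipulation."""
    bits = struct.unpack("<Q", struct.pack("<d", x))[0]
    biased_exp = (bits >> 52) & 0x7FF          # 11-bit exponent field
    if x <= 0.0 or biased_exp == 0 or biased_exp == 0x7FF:
        raise ValueError("sketch handles positive normal doubles only")
    e = biased_exp - 1022                      # exponent such that m lands in [0.5, 1)
    # Overwrite the exponent field with 1022, the exponent of values in [0.5, 1)
    m_bits = (bits & ~(0x7FF << 52)) | (1022 << 52)
    m = struct.unpack("<d", struct.pack("<Q", m_bits))[0]
    return m, e
```

On the inputs it accepts, this should agree with `math.frexp`, so either one can feed the reduction step.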
However, there's a big design space here. For example, if we subtract 0.5 from e and multiply m by the square root of two, we end up implementing log_rr on the interval [sqrt(2)/2, sqrt(2)]. Now if we'd like we can apply the identity log(1/x) = -log(x) to further shrink the interval to [1, sqrt(2)]. This could be good or bad. The interval is smaller, which ought to make the Remez approximation more accurate, or faster, since we can trade accuracy for speed. But on the other hand, the extra multiplication and division steps introduce error.
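Here is a sketch of that alternative reduction, mainly to make the bookkeeping explicit. The names `reduce_sqrt2` and `log_variant` are hypothetical, and `math.log` stands in for the Remez-based log_rr so that the example runs on its own.

```python
import math

LOG2 = math.log(2.0)
SQRT2 = math.sqrt(2.0)

def reduce_sqrt2(x):
    """Return (r, e_shifted, sign) with log(x) == e_shifted * LOG2 + sign * log(r)."""
    m, e = math.frexp(x)      # x == m * 2**e, m in [0.5, 1)
    r = m * SQRT2             # r in [sqrt(2)/2, sqrt(2))
    e_shifted = e - 0.5       # compensates for the extra factor sqrt(2) == 2**0.5
    sign = 1.0
    if r < 1.0:               # apply log(1/r) == -log(r)
        r, sign = 1.0 / r, -1.0
    return r, e_shifted, sign

def log_variant(x, log_rr=math.log):
    # log_rr only ever sees inputs in [1, sqrt(2)], up to rounding
    r, e_shifted, sign = reduce_sqrt2(x)
    return e_shifted * LOG2 + sign * log_rr(r)
```

The multiplication by `SQRT2` and the division `1.0 / r` are exactly the extra rounding steps that the trade-off above is about.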
The range reduction steps available depend a lot on the mathematical function in question. Most implementations of sin leverage sin(x + 2pi) = sin(x), sin(pi - x) = sin(x), and sin(-x) = -sin(x). Implementations of tan leverage tan(x + pi) = tan(x) and tan(-x) = -tan(x) but some also leverage tan(pi/2 - x) = 1/tan(x). It depends a lot on your accuracy and performance goals, and contextual information like the range of inputs you ultimately want to support and how fast operations like table lookup and division are.
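To make the sin identities concrete, here is a naive sketch that chains all three so the core routine only sees inputs in [0, pi/2]. The name `sin_reduced` is hypothetical and `math.sin` stands in for the polynomial core. Real implementations reduce modulo 2pi much more carefully; using `math.remainder` with a double-precision 2pi loses accuracy for large inputs.

```python
import math

def sin_reduced(x, sin_rr=math.sin):
    """Evaluate sin(x) while only calling sin_rr on [0, pi/2]."""
    sign = 1.0
    # sin(x + 2*pi) == sin(x): slide into [-pi, pi]
    x = math.remainder(x, 2.0 * math.pi)
    # sin(-x) == -sin(x): fold into [0, pi]
    if x < 0.0:
        x, sign = -x, -sign
    # sin(pi - x) == sin(x): fold into [0, pi/2]
    if x > math.pi / 2.0:
        x = math.pi - x
    return sign * sin_rr(x)
```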
The challenge, then, is to automatically synthesize range reduction strategies for arbitrary functions. Ideally, we want to synthesize all possible range reduction strategies, in cases where there is more than one, so that we can instantiate each one and compare them.
The monoid model
The core of our approach is to model a range reduction strategy group-theoretically. Specifically, we are interested in pairs of range reduction and reconstruction functions \(s\) and \(t\) such that:
\[ t(f(s(x))) = f(x) \]
Here \(s\) does the reduction and \(t\) does the reconstruction. However, instead of thinking about a single \((s, t)\) pair, let's think about a family of these pairs. (This was a major stumbling block for us for about a month.) The full process for evaluating f(x) will be something like this:
- Examine x and, based on it, choose a pair of \(s_x\) and \(t_x\)
- Apply \(s_x\)
- Apply f
- Apply \(t_x\)
Note that the \(s_x\) and \(t_x\) here are related, so there is some "state" that can be passed between them—state like whether the output needs to be negated or how many multiples of 2 were stripped off the input.
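The four steps can be sketched directly in code. Everything here is illustrative: `evaluate` and `choose_pair_odd` are hypothetical names, and the selector shown only handles the odd-function case, where the shared state is the sign of the input.

```python
import math
from typing import Callable, Tuple

# A pair (s_x, t_x): reduction and reconstruction
Pair = Tuple[Callable[[float], float], Callable[[float], float]]

def choose_pair_odd(x: float) -> Pair:
    """Hypothetical selector for an odd function; the state is the sign of x."""
    if x < 0.0:
        return (lambda v: -v), (lambda y: -y)   # reduce with -x, reconstruct with -y
    return (lambda v: v), (lambda y: y)          # identity pair

def evaluate(f, choose_pair, x):
    s_x, t_x = choose_pair(x)   # 1. examine x and choose the pair
    reduced = s_x(x)            # 2. apply s_x
    y = f(reduced)              # 3. apply f, which only sees the reduced input
    return t_x(y)               # 4. apply t_x
```

For example, `evaluate(math.atan, choose_pair_odd, -2.0)` only ever calls `math.atan` on a nonnegative input.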
Let's go back to the \((s, t)\) pairs. The point of thinking of this as a family of operations is that we can now add internal structure to this family. Most importantly, these pairs form a monoid, composed like so:
\[ t_2(t_1(f(s_1(s_2(x))))) = t_2(f(s_2(x))) = f(x) \]
So we can now think of our range reduction strategy as returning:
- A reduction monoid \(M\) containing sound \((s, t)\) pairs and some simple internal structure
- A map \(g : \mathbb{R} \to M\) that selects specific \(s_x\) and \(t_x\) for each input \(x\)
We want two properties to be true of our \(s_x\) and \(t_x\) pairs. First, there is a soundness property, \(\forall x, t_x(f(s_x(x))) = f(x)\). Second, there is a completeness property, which is that there's some small interval \(I\) such that \(\forall x, s_x(x) \in I\). This second property guarantees that the Remez algorithm will successfully synthesize a good implementation of f.
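One way to picture the monoid structure is as a small data type with a product and an identity, plus a numerical soundness check. This is only a sketch under my own naming (`Pair`, `compose`, `is_sound` are all hypothetical), and sampling gives evidence of soundness, not a proof.

```python
import math
from dataclasses import dataclass
from typing import Callable, Iterable

@dataclass(frozen=True)
class Pair:
    s: Callable[[float], float]   # reduction
    t: Callable[[float], float]   # reconstruction

IDENTITY = Pair(s=lambda x: x, t=lambda y: y)

def compose(p1: Pair, p2: Pair) -> Pair:
    """Monoid product: reduce with s1(s2(x)), reconstruct with t2(t1(y))."""
    return Pair(s=lambda x: p1.s(p2.s(x)), t=lambda y: p2.t(p1.t(y)))

def is_sound(f, pair: Pair, xs: Iterable[float], tol: float = 1e-9) -> bool:
    """Numerically test t(f(s(x))) == f(x) on sample points."""
    return all(abs(pair.t(f(pair.s(x))) - f(x)) <= tol for x in xs)

# Two sound pairs for sin: periodicity and oddness
PERIODIC = Pair(s=lambda x: x - 2.0 * math.pi, t=lambda y: y)
ODD = Pair(s=lambda x: -x, t=lambda y: -y)
SAMPLES = [-3.0, -0.5, 0.0, 1.2, 4.0]
```

With these definitions, `is_sound(math.sin, compose(ODD, PERIODIC), SAMPLES)` checks that the product of two sound pairs is again sound on the samples.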
Some examples
Here are some examples of \((s, t)\) pairs, presented as equations:
- f(x) = f(x + p), for periodic functions
- f(x) = f(-x), for even functions
- f(x) = -f(-x), for odd functions
- f(x) = f(x/2) + c, for logarithm-like functions
- f(x) = 1/f(p - x), for functions like tangent
- f(x, y) = f(y, x), for symmetric functions
- f(x, y) = -f(y, x), for anti-symmetric functions
For each of them you might use the following maps \(g\):
- For periodic functions, use Cody-Waite range reduction to compute x modulo p
- For even functions, take the absolute value of the input
- For odd functions, split off the sign bit and re-attach it at the end
- For logarithm-like functions, use frexp to reduce the input to [0.5, 1]
- For tangent-like functions, cut the input range in half
- For symmetric functions, sort the inputs
- For anti-symmetric functions, sort the inputs but count swaps and use the parity
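A few of these maps are simple enough to write down. In this sketch each function returns the reduced input along with whatever state the reconstruction needs; the `g_*` names are made up for illustration, and the periodic and tangent-like cases are left out because they need more care.

```python
import math

def g_even(x):
    return abs(x), None                    # no state needed for reconstruction

def g_odd(x):
    negate = math.copysign(1.0, x) < 0.0   # split off the sign bit
    return abs(x), negate                  # negate the result at the end if set

def g_log_like(x):
    m, e = math.frexp(x)                   # x == m * 2**e, m in [0.5, 1)
    return m, e                            # e counts the factors of 2 stripped off

def g_symmetric(x, y):
    return (x, y) if x <= y else (y, x)    # sort the inputs

def g_antisymmetric(x, y):
    if x <= y:
        return (x, y), 0                   # zero swaps: keep the sign
    return (y, x), 1                       # one swap: negate the result
```

For an anti-symmetric function such as f(x, y) = x - y, the result is recovered as `(-1) ** swaps * f(a, b)` where `(a, b), swaps = g_antisymmetric(x, y)`.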
Synthesizing the monoid
Our next step is to synthesize the monoid \(M\). The key here is to leverage the monoid structure.
First, let's find some arbitrary pairs \(s\) and \(t\). To do so, create an e-graph containing f(x) and iterate it with a bunch of pre-existing rules. Those rules would include identities for the basic functions, possibly synthesized by something like Ruler.
After a bunch of iteration, the e-graph ought to record equalities like t(f(s(x))) = f(x). We now need to efficiently find those equalities. To do so, extract the simplest expression from each enode, with one twist: enodes that match the pattern f get cost 0. In other words, extract the simplest expression, ignoring the cost of any expressions that are equal to f(?). Now examine all enodes in the eclass of f(x). These should contain all expressions t(f(s(x))), including of course the identity transformation f(x).
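Once expressions of the shape t(f(s(x))) have been extracted, they still need to be split into their \(s\) and \(t\) parts. Here is a sketch of that last step only, assuming (purely for illustration) that extracted expressions are nested tuples like `("neg", ("f", ("neg", "x")))`. This is not the API of any particular e-graph library, and `split_pair` is a hypothetical helper.

```python
def split_pair(expr):
    """Split an expression of shape t(f(s(x))) into (t, s).

    t is returned with the f-application replaced by the placeholder "HOLE".
    Returns None unless expr contains exactly one outermost application of f.
    """
    found = []

    def walk(e):
        if isinstance(e, tuple) and e[0] == "f":
            found.append(e[1])      # the argument of f is the reduction s(x)
            return "HOLE"           # do not descend into s
        if isinstance(e, tuple):
            return (e[0],) + tuple(walk(arg) for arg in e[1:])
        return e

    t = walk(expr)
    if len(found) != 1:
        return None
    return t, found[0]
```

The identity transformation f(x) comes out as `("HOLE", "x")`, and expressions that never mention f are rejected.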
Now that we have a bunch of (s, t) pairs, we next want to consider the monoid \(M\) generated by them. However, in general, this monoid is pretty complicated: it is a free monoid with reduced representations that contain arbitrarily many (s, t) pairs in sequence. We're not interested in math function implementations that contain loops. We want \(M\) to have some simple representation.
To make things concrete, let's require that \(M\) contain only the transformations \(g_1^{k_1} g_2^{k_2} \dotsb g_n^{k_n}\), where the \(g_i\) are the generating (s, t) pairs, in some fixed order, while the \(k_i\) are integers. Then the map \(g\) could take in an input \(x\) and return a list of integers describing an element of \(M\).
To avoid loops, we first need to be able to apply \(s^k(x)\) for any generator \(s\). This is just a mathematical expression in two variables, \(s(x, k)\). To synthesize this expression, we use e-graph intersection. Specifically, for a concrete natural number \(i\), construct an e-graph \(G_i\) containing:
- The equality \(y = s(s(\dotsb s(x)))\), with \(i\) applications of \(s\)
- The equality \(k = i\), where k is a variable expression
Iterate this e-graph; eventually, it ought to contain the expression \(y = s(x, k)\), no matter the integer \(i\). Therefore, if we intersect a bunch of these e-graphs together (in practice, probably a few will suffice) we will lose all the specific equalities like \(y = s(s(x))\) but we will keep the generic equality \(y = s(x, k)\). Do the same for \(t\).
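The intuition behind the intersection can be mimicked numerically, without any e-graph: a candidate closed form survives only if it agrees with the iterated generator for every iteration count tried. This is an analogy rather than e-graph intersection itself, and the names `iterate` and `agrees_for_all_counts` are hypothetical.

```python
def iterate(s, x, i):
    """Apply s to x exactly i times."""
    for _ in range(i):
        x = s(x)
    return x

def agrees_for_all_counts(s, candidate, xs, counts=(0, 1, 2, 3, 5)):
    """True if candidate(x, k) matches k-fold iteration of s for every k in counts."""
    return all(
        abs(iterate(s, x, i) - candidate(x, i)) <= 1e-9 * (1.0 + abs(x))
        for i in counts
        for x in xs
    )

halve = lambda x: x / 2.0               # generator s for a logarithm-like function
generic = lambda x, k: x / 2.0 ** k     # candidate closed form s(x, k)
specific = lambda x, k: x / 4.0         # only correct when k == 2
```

Here `generic` passes for every count, while `specific` is filtered out as soon as a count other than 2 is included, much like the specific equalities that disappear under intersection.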
Now, with multiple generators, there is still an infinite number of words \(g_i^{k_i} g_j^{k_j} \dotsb\) that we could use for range reduction. We can search through them all to find good ones—or, we can reduce the search space using commutators. For two transformations \(s_1\) and \(s_2\), we need to synthesize equalities like:
\[ \forall x, s_2(s_1(x)) = s_1^a(s_2^b(x)) \]
When we have more than two transformations, this gets a little trickier. For example, with \(s_1\), \(s_2\), and \(s_3\), we might need an equality like:
\[ \forall x, s_3(s_1(x)) = s_1^a(s_3^b(s_2^c(x))) \]
Furthermore, to guarantee termination, there needs to be some rule restricting which generators can appear on the right-hand side for a given left-hand pair.
Luckily, in practice we expect relatively few generators and also for the \((a, \dotsc, c)\) values to be small, so worst case we can enumerate the possible orders.
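Since the exponents are expected to be small, a brute-force search is plausible. Below is a sketch for two concrete generators, negation (\(s_1\)) and sliding by a period (\(s_2\)), that tests candidate \((a, b)\) values on sample points. The period `P`, the bound, and all function names are assumptions made for the example, and agreement on samples is only evidence that the identity holds.

```python
import itertools

P = 2.0  # hypothetical period

def s1_pow(x, k):
    """Negation: s1(x) = -x, so s1^k(x) depends only on the parity of k."""
    return -x if k % 2 else x

def s2_pow(x, k):
    """Sliding: s2(x) = x + P, so s2^k(x) = x + k*P (k may be negative)."""
    return x + k * P

def find_commutation(xs, bound=2):
    """Search small (a, b) with s2(s1(x)) == s1^a(s2^b(x)) on all samples."""
    hits = []
    for a, b in itertools.product(range(-bound, bound + 1), repeat=2):
        if all(abs(s2_pow(s1_pow(x, 1), 1) - s1_pow(s2_pow(x, b), a)) < 1e-12
               for x in xs):
            hits.append((a, b))
    return hits
```

On generic samples the search finds b = -1 with odd a, matching the hand calculation -x + P = -(x - P).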
It's also possible that there may be multiple monoids \(M\) that can be generated this way. At the very least, the order in which the various generators are applied is important. For example, if a function is both periodic and also odd (like sin), applying periodicity first and then oddness reduces the input to a half-period, while applying oddness first and periodicity second can only reduce the input to a full period.
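The difference between the two orders shows up if we just compute the reduced inputs. This sketch uses a period of 2pi and hypothetical function names; `math.remainder` and `math.fmod` are stand-ins for a careful modular reduction.

```python
import math

TWO_PI = 2.0 * math.pi

def period_then_odd(x):
    r = math.remainder(x, TWO_PI)   # slide into [-pi, pi]
    return abs(r)                   # fold into [0, pi], a half-period

def odd_then_period(x):
    r = abs(x)                      # fold into [0, inf)
    return math.fmod(r, TWO_PI)     # slide into [0, 2*pi); no fold is left to apply

xs = [0.37 * k - 20.0 for k in range(120)]
largest_half = max(period_then_odd(x) for x in xs)
largest_full = max(odd_then_period(x) for x in xs)
```

On these samples `largest_half` stays within pi while `largest_full` exceeds it, so the second ordering leaves a wider interval for the polynomial to cover.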
In many cases, the monoid is a group, such as for the transformation f(x) = f(x + p). Also in many cases, generators have finite order, like f(x) = -f(-x). Ideally we would automatically detect these cases as well.
Synthesizing the monoid map
The structure we've chosen for \(M\) makes it easier to synthesize \(g\). For example, suppose we have a function that is both odd and periodic, with generators \(n\) (for negate) and \(p\) (for periodic). Our function \(g\) basically needs to look at an input \(x\) and determine how many periods to slide it by and whether to negate it.
Now, suppose we've adopted a representation of \(M\) as \(p^k n^l\), so that we first slide the input around and then negate it. In this case, our input comes in as an arbitrary real number; then the sliding can reduce it to the range \([-p/2, p/2]\), and the negation can further reduce it to \([0, p/2]\). We can consider these two steps in isolation. First, we need to figure out how we can reduce an arbitrary input by sliding with some \(g_p\). Then, we need to figure out which inputs need to be negated with some \(g_n\).
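For this particular example the map \(g\) is easy to write by hand, which at least shows what a synthesized version would need to produce. In this sketch `g_odd_periodic` returns the exponents \(k\) and \(l\) along with the reduced input; the names are hypothetical, and rounding `x / p` is a naive stand-in for a careful reduction such as Cody-Waite.

```python
import math

def g_odd_periodic(x, p):
    """Return (k, l, r): slide by k periods, negate l times, reduced input r."""
    k = round(x / p)           # number of periods to slide by
    r = x - k * p              # now in [-p/2, p/2], up to rounding
    l = 1 if r < 0.0 else 0    # whether the input needs to be negated
    if l:
        r = -r                 # now in [0, p/2]
    return k, l, r

def eval_odd_periodic(f_rr, x, p):
    """Evaluate an odd, p-periodic function using f_rr only on [0, p/2]."""
    k, l, r = g_odd_periodic(x, p)
    y = f_rr(r)
    return -y if l else y      # undo the negation; sliding needs no reconstruction
```

For instance, `eval_odd_periodic(math.sin, 100.0, 2.0 * math.pi)` calls `math.sin` only on a value in [0, pi].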
One challenge to this is how arbitrary it is. Sliding can map an input into any interval of length \(p\). Similarly, why reduce to \([0, p/2]\) instead of \([-p/2, 0]\)?
Also, these \(g_p\) and \(g_n\) functions map arbitrary real inputs to integers. How do we synthesize them? One option is to search a pre-existing library of functions for those that are invariant under the \(s\) transformation, like fmod for sliding and fabs for negation. But that doesn't give us access to the k value we are looking for. Also, what if we don't have enough pre-existing functions?
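The library-search option could look something like the following sketch, which tests each candidate function for invariance under each transformation on sample points. The tiny `LIBRARY`, the period `P`, and the sample set are all assumptions. The samples are kept nonnegative because `math.fmod` takes the sign of its first argument, so it is only invariant under sliding when the sign does not change.

```python
import math

P = 2.0 * math.pi                      # hypothetical period
SAMPLES = [0.1, 1.0, 2.5, 7.0, 40.0]   # nonnegative on purpose, see note above

TRANSFORMS = {"slide": lambda x: x + P, "negate": lambda x: -x}
LIBRARY = {"fabs": math.fabs, "fmod": lambda x: math.fmod(x, P)}

def invariances(tol=1e-9):
    """Map each library function to the transformations it is invariant under."""
    found = {}
    for name, h in LIBRARY.items():
        found[name] = [
            t_name for t_name, s in TRANSFORMS.items()
            if all(abs(h(s(x)) - h(x)) <= tol for x in SAMPLES)
        ]
    return found
```

This recovers the pairing of fmod with sliding and fabs with negation, but as noted it only yields the reduced input, not the integer k.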
This part of the process is an open question.
Conclusion
This monoidal model for range reduction provides a clear framework for searching for and implementing range reductions. Moreover, at least half of it seems synthesizable. (I'm not sure if the second half is synthesizable, but probably with some caveats or hacks it is.) The hope is that this can therefore provide crucial assistance to math function implementation strategies. It could also be useful in Herbie, where reducing the input space sometimes helps identify globally-useful rewrites.