How Operator Fusion Affects Error
I have recently become a big fan of the condition number formalism for floating-point error; it is a simple mental model that explains a lot about floating point. In this blog post I want to use that mental model to investigate "operator fusion": basically, why library functions like log1p, cospi, or hypot improve accuracy.1 [1 Actually, one of the big effects of hypot is to avoid overflow, which condition numbers don't address, so, uhh, it's only part of the story.]
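To make that footnote concrete, here is a tiny Python sketch (mine) of the overflow effect, which is exactly the part condition numbers don't capture:

```python
import math

x = y = 1e200                    # representable doubles, but x*x is not
print(math.sqrt(x * x + y * y))  # inf: the intermediate x*x already overflowed
print(math.hypot(x, y))          # ~1.4142e+200: hypot avoids the intermediate overflow
```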
The condition number formalism
Here's the basic way the condition number formalism works. Every floating-point operation may introduce some rounding error. You can think of this as absolute rounding error or relative rounding error; there are variations of the formalism for either. But either way, every floating-point operation may introduce some error.
Then, every time the result of one floating-point operation is the input to another floating-point operation—every edge in the data flow graph—that error is amplified (and potentially flips sign). Typically introduced relative error is very small, but these amplification factors can be very large, so the error can grow large enough to make the results incorrect.
There is also higher-order error in this formalism, which we won't talk about here. I personally think some of my work with Artem on Rival suggests that you could treat higher-order error soundly using basically the same framework, but anyway, higher-order error is usually negligibly small.
Critically, the amplification factors are a property of the real-number function, so there's nothing you can do to change them. You can only affect the introduced errors (by using higher precision, or using different / better implementations of floating-point operators).
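To get a feel for these amplification factors, here is a quick numerical sketch (mine, using a crude finite-difference derivative): the factor for f at x is the logarithmic derivative x · f'(x) / f(x), taken in absolute value, and it depends only on the real-number function and the input, not on any implementation.

```python
import math

def ampl(f, x, h=1e-6):
    """Estimate the amplification factor of f at x: |x * f'(x) / f(x)|,
    using a central finite difference for f' (crude, but fine for a sketch)."""
    deriv = (f(x * (1 + h)) - f(x * (1 - h))) / (2 * x * h)
    return abs(x * deriv / f(x))

print(ampl(math.sqrt, 2.0))      # ~0.5: sqrt barely amplifies incoming error
print(ampl(math.log, 1e6))       # ~0.07: log is tame far from 1
print(ampl(math.log, 1.0001))    # ~1e4: log near 1 amplifies error enormously
print(ampl(math.log1p, 1e-4))    # ~1.0: log1p near 0 is well-conditioned
```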
A worked example
In this formalism, you can analyze a floating-point expression like log(1 + x) by looking at all data flow paths through the expression. They are:
- 1 to + to log
- x to + to log
- + to log
- just log
Then the overall error is the sum of four terms:
- intro(1) ampl(+, 1) ampl(log)
- intro(x) ampl(+, 2) ampl(log)
- intro(+) ampl(log)
- intro(log)
Some of these terms we know: intro(1) is 0, because the constant 1.0 is exact in floating-point. So that term drops out. We might also have bounds on some of them. For example, addition is correctly rounded on almost all hardware, so intro(+) is half an ULP at most. And log is faithfully rounded in most libraries, so that's one ULP at most. So we have a total error of:
intro(x) ampl(+, 2) ampl(log) + (½ ampl(log) + 1)
Let's look at the first term; it has an intro(x) as well as two ampl terms. Now, the definition of ampl terms is really simple (it is the logarithmic derivative of the function) and satisfies the chain rule, so we can combine these ampl terms into a single amplification factor (for x in log(1 + x), which we often call log1p). So that gets us to:
intro(x) ampl(log1p) + (½ ampl(log) + 1)
Finally, we can look up the log function and see that its amplification factor is one over its output, which gets us to:
intro(x) ampl(log1p) + (1 + ½ / log(1 + x))
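If you want to sanity-check that chain-rule step, here is a small numerical sketch (my own): the amplification of x through + (which works out to |x / (1 + x)|) times the amplification of log at 1 + x should equal the amplification factor of log1p at x.

```python
import math

def check(x):
    ampl_add_x = abs(x / (1 + x))                     # ampl(+, 2): |x * 1 / (1 + x)|
    ampl_log   = abs(1 / math.log(1 + x))             # ampl(log) at y = 1 + x: |y * (1/y) / log(y)|
    ampl_log1p = abs(x / ((1 + x) * math.log1p(x)))   # logarithmic derivative of log1p at x
    return ampl_add_x * ampl_log, ampl_log1p

for x in (1e-4, 0.5, 3.0):
    chained, direct = check(x)
    print(f"x={x}: chained={chained:.6g}, direct={direct:.6g}")  # the two columns agree
```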
Generalizing a little bit, we see that analyzing the error of a function gets you:
- One term for each argument, multiplied by an amplification factor for the expression as a whole.
- Plus an "introduced" term which depends on the implementation of the function.
So, when we compose two atomic operations like + and log, we have a composite function that has its own amplified and introduced errors.
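To write that down in general (sketching, in the same notation as above): if we implement an expression E with arguments x1, ..., xn, its error model has the shape
intro(x1) ampl(E, 1) + ... + intro(xn) ampl(E, n) + intro(E)
where ampl(E, i) is the amplification factor of E with respect to its i-th argument, and intro(E) collects the error introduced inside whatever implementation of E we chose, each piece amplified by the part of E downstream of it (for log(1 + x), that was ½ ampl(log) + 1).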
Fused operators
Suppose we now look at one of these "fused" operations, say log1p(x). Why is it (often) more accurate than log(1 + x)?
Well, let's write down its error model. There are two data flow paths through this smaller expression:
- x to log1p
- just log1p
The error model is then:
intro(x) ampl(log1p) + intro(log1p)
Note that the first term, the x term, is the same as above. This is because, recall, amplification terms do not depend on implementation choices. But now look at the second term (the introduced term); this term does depend on implementation choices, and specifically, using log1p is better than log(1 + x) by exactly ½ / log(1 + x). For many inputs this isn't a big win, but for x near 0, log(1 + x) goes to 0, so ½ / log(1 + x) goes to infinity. In those cases we reduce error significantly.
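You can watch this happen numerically. Here's a small Python sketch (mine), using the decimal module for a high-precision reference value:

```python
from decimal import Decimal, getcontext
import math

getcontext().prec = 50                      # high-precision reference
x = 1e-10
ref = (Decimal(x) + 1).ln()                 # "true" log(1 + x) for this exact double x

naive = math.log(1 + x)                     # 1 + x rounds first, losing most of x's digits
fused = math.log1p(x)                       # fused operator: no intermediate rounding

print(float((Decimal(naive) - ref) / ref))  # relative error on the order of 1e-7
print(float((Decimal(fused) - ref) / ref))  # relative error around 1e-16 or less -- about an ULP
```

The naive form's error is consistent with the ½ ampl(log) term above: up to half an ULP of rounding near 1 (about 1e-16 relative) amplified by 1 / log(1 + x), roughly 1e10 here, allows a relative error of order 1e-6.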
So is the lesson that fused operations are always better? That can't be right. For example, log1p(x - 1) can be fused back into log(x); they can't both be better than each other! In fact, there's no real distinction between fused and unfused operations at all. So why is the error lower?
It's because when we moved from log(1 + x) to log1p(x), we removed one internal rounding error. That rounding error was small (at most half an ULP) but it was amplified significantly in our expression, so removing it reduces error by a lot.
But also keep in mind what this doesn't do. The amplification factor of log1p itself doesn't change, which means that error coming into log1p is still amplified by the same amount. And log1p itself still introduces some error, just a lot less than log(1 + x) did. So if later operations amplify error by a lot, you'll still get a lot of error. More precisely, if you have some large expression that uses log(1 + x) or log1p, then using one or the other only affects one of potentially very many sources of error in the expression. If you were already computing log(1 + x), then sure, why not use log1p. But typically, you need to do some rewriting of the expression to even get it into the form where a log1p appears; in that case it's only a good idea if the rewriting you do doesn't introduce other sources of error elsewhere.
Generalizing a bit, when we are implementing some expression E, we usually have many different options for how to implement it. Every choice we make adds an introduced error, and we can compute the amplification factor for that error and add its contribution to our total.
For example, suppose we want to implement the quadratic formula. Our first decision is that we're going to compute d = sqrt(b^2 - 4 a c) somewhere; the remaining part of the expression is (-b + d) / (2 a), so the introduced error in d will be multiplied by its amplification factor. Luckily, as is well known, we can arrange for d and b to have opposite signs (by choosing which root to compute this way), so the addition -b + d doesn't cancel; division has low amplification, and multiplication by 2 and negation are exact, so this amplification factor is low. This tells us that we've made a good choice computing d as an intermediate value. Now we can recurse and consider how to compute the two pieces accurately.2 [2 Though, again, note that this formalism doesn't think about overflow, which is also a big concern with the quadratic.]
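Here is a minimal Python sketch of that choice (mine, and it ignores the overflow issue from the footnote): compute d once, pick the root where b and the ±d term don't cancel when added, and recover the other root from the product of the roots, c / a.

```python
import math

def quadratic_roots(a, b, c):
    """Roots of a*x^2 + b*x + c, avoiding the cancellation in (-b + d) / (2 a).
    A sketch only: assumes real roots and a != 0, and ignores overflow."""
    d = math.sqrt(b * b - 4 * a * c)   # the intermediate value d from above
    # Choose the sign so that b and the d-term have the same sign when added:
    # no cancellation, so the addition's amplification factor stays low.
    if b >= 0:
        q = -(b + d) / 2
    else:
        q = -(b - d) / 2
    return q / a, c / q                # second root via x1 * x2 = c / a

# x^2 - 1e8 x + 1 has roots near 1e8 and 1e-8; the naive formula loses the small root.
print(quadratic_roots(1, -1e8, 1))
```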
Or, thinking about this bottom-up, I like to imagine that there are lots and lots of ways to get from your expression's inputs to its outputs. Some of those ways have high amplification factor and some have low amplification factor; you basically need to find a path from start to finish where all of the amplification factors along the way are small.
Conclusion
The condition number model is a composable way to think about floating-point error. In this model, you have two basic ways to reduce error: either using higher precision / more accurate operators, which reduces introduced error, or approaching your problem in a different way, so as to end up with different amplification factors. One specific thing you can do is use fused operators like log1p, which, by fusing operations, might avoid an introduced error that is highly amplified. But while that definitely helps in some examples, it's really a question of how much you need to rearrange the expression to use the fused operator and whether or not that rearrangement introduces more error than the fused operator saves.