How Operator Fusion Affects Error

By Pavel Panchekha

Share under CC-BY-SA.

I have recently become a big fan of the condition number formalism for floating-point error; it is a simple mental model that explains a lot about floating point. In this blog post I want to use that mental model to investigate "operator fusion": basically, why library functions like log1p, cospi, or hypot improve error.[1]

The condition number formalism

Here's the basic way the condition number formalism works. Every floating-point operation may introduce some rounding error. You can think of this as absolute rounding error or relative rounding error; there are variations of the formalism for either. Either way, every floating-point operation may introduce some error.

Then, every time the result of one floating-point operation is the input to another floating-point operation—every edge in the data flow graph—that error is amplified (and potentially flips sign). The introduced relative error is typically very small, but these amplification factors can be very large, so the error can grow large enough to make the results incorrect.
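
To make that amplification step concrete, here is a tiny Python illustration (mine, not from the post). The only introduced error is the rounding of 1.0 + x, which is at most half an ULP, roughly one part in 10^16; the subtraction that follows amplifies it into a relative error of about 10^-4:

    x = 1e-12
    y = (1.0 + x) - 1.0    # exactly 1e-12 in real arithmetic
    print(y)               # roughly 1.0000889e-12 with IEEE doubles
    print(abs(y - x) / x)  # relative error around 9e-5, not 1e-16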

There is also higher-order error in this formalism, but we won't talk about it here. I personally think some of my work with Artem on Rival suggests that you could treat higher-order error soundly using basically the same framework, but in any case, higher-order error is usually negligibly small.

Critically, the amplification factors are a property of the real-number function, so there's nothing you can do to change them. You can only affect the introduced errors (by using higher precision, or using different / better implementations of floating-point operators).

A worked example

In this formalism, you can analyze a floating-point expression like log(1 + x) by looking at all data flow paths through the expression. They are:

  • 1 to + to log
  • x to + to log
  • + to log
  • just log

Then the overall error is the sum of four terms:

  • intro(1) ampl(+, 1) ampl(log)
  • intro(x) ampl(+, 2) ampl(log)
  • intro(+) ampl(log)
  • intro(log)

Some of these terms we know: intro(1) is 0, because the constant 1.0 is exact in floating-point. So that term drops out. We might also have bounds on some of them. For example, addition is correctly-rounded on almost all hardware, so intro(+) is half an ULP at most. And log is faithfully rounded in most libraries, so that's one ULP at most. So we have a total error of:

intro(x) ampl(+, 2) ampl(log) + (½ ampl(log) + 1)

Let's look at the first term; it has an intro(x) as well as two ampl terms. Now, the definition of ampl terms is really simple (it is the logarithmic derivative of the function) and satisfies the chain rule, so we can combine these ampl terms into a single amplification factor (for x in log(1 + x), which we often call log1p). So that gets us to:

intro(x) ampl(log1p) + (½ ampl(log) + 1)
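
As an aside, the amplification factor in play here is the standard relative condition number. Writing it out in LaTeX notation (mine, not the post's) makes the chain-rule step explicit; the second identity is what lets us collapse ampl(+, 2) ampl(log) into the single factor ampl(log1p):

    \mathrm{ampl}(f, x) = \left| \frac{x \, f'(x)}{f(x)} \right|,
    \qquad
    \mathrm{ampl}(f \circ g, x) = \mathrm{ampl}(f, g(x)) \cdot \mathrm{ampl}(g, x)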

Finally, we can look up the log function and see that its amplification factor is one over its output, which gets us to:

intro(x) ampl(log1p) + (1 + ½ / log(1 + x))
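
To get a feel for that factor, here is a quick numerical check in Python (mine, not from the post). The combined amplification factor for x in log(1 + x), i.e. ampl(log1p), stays close to 1 as x approaches 0, while ampl(log) on its own blows up:

    import math

    def ampl_log1p(x):
        # |x / ((1 + x) log(1 + x))|, i.e. ampl(+, 2) times ampl(log)
        return abs(x / ((1.0 + x) * math.log1p(x)))

    def ampl_log(x):
        # |1 / log(1 + x)|, i.e. ampl(log) evaluated at 1 + x
        # (log1p is used here only to evaluate the factor accurately)
        return abs(1.0 / math.log1p(x))

    for x in [1.0, 1e-4, 1e-8, 1e-12]:
        print(x, ampl_log1p(x), ampl_log(x))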

Generalizing a little bit, we see that analyzing the error of a function gets you:

  • One term for each argument, multiplied by an amplification factor for the expression as a whole.
  • Plus an "introduced" term which depends on the implementation of the function.

So, when we compose two atomic operations like + and log, we get a composite function that has its own amplification factors and its own introduced error.

Fused operators

Suppose we now look at one of these "fused" operations, like let's say log1p(x). Why is it (often) more accurate than log(1 + x)?

Well, let's write down its error model. There are two data flow paths through this smaller expression:

  • x to log1p
  • just log1p

The error model is then:

intro(x) ampl(log1p) + intro(log1p)

Note that the first term, the x term, is the same as above. This is because, recall, amplification terms do not depend on implementation choices. But now look at the second term (the introduced term); this term does depend on implementation choices, and specifically, using log1p is better than log(1 + x) by exactly ½ / log(1 + x). For many inputs this isn't a big win, but for x near 0, log(1 + x) goes to 0 so ½ / log(1 + x) goes to infinity. In those cases we reduce error significantly.
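
Here is a quick check of that claim in Python (my code, not from the post), measuring both implementations against a 50-digit reference computed with the decimal module:

    import math
    from decimal import Decimal, getcontext

    getcontext().prec = 50

    def rel_error(approx, x):
        exact = (Decimal(1) + Decimal(x)).ln()   # high-precision log(1 + x)
        return abs((Decimal(approx) - exact) / exact)

    x = 1e-9
    print(rel_error(math.log(1.0 + x), x))   # about 1e-8: most digits are wrong
    print(rel_error(math.log1p(x), x))       # about 1e-16 or less: nearly exact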

So is the lesson that fused operations are always better? That can't be right. For example, log1p(x - 1) can be fused back into log(x); they can't both be better than each other! In fact, there's no real distinction between fused and unfused operations at all. So why is error better?

It's because when we moved from log(1 + x) to log1p(x), we removed one internal rounding error. That rounding error was small (one ULP) but it is amplified significantly in our expression, so removing it reduces error by a lot.

But also keep in mind what this doesn't do. The amplification factor of log1p itself doesn't change, which means that error coming in to log1p is still amplified the same amount. And log1p itself still produces some error, just a lot less. So if later operations amplify error by a lot, you'll still get a lot of error. More precisely, if you have some large expression that uses log(1 + x) or log1p, then using one or the other only affects one of potentially very many sources of error in the expression. If you were already computing log(1 + x), then sure, why not use log1p. But typically, you need to do some rewriting of the expression to even get it into the form where a log1p appears; in that case it's only a good idea if the rewriting you do doesn't introduce other sources of error elsewhere.

Generalizing a bit, when we are implementing some expression E, we usually have many different options for how to implement it. With every choice we make, we add an introduced error, and we can compute the amplification factor for that error and add its contribution to our total.

For example, suppose we want to implement the quadratic formula. Our first decision is that we're going to compute d = sqrt(b^2 - 4 a c) somewhere; the remaining part of the expression is (-b + d) / 2 a, so the introduced error in d will be multiplied by its amplification factor. Luckily, as is well-known, we can make sure d and b always have opposite signs (so that -b + d doesn't cancel), division has low amplification, and multiplication by 2 and negation are exact, so this amplification factor is low. This tells us that we've made a good choice computing d as an intermediate value. Now we can recurse and consider how to compute the two pieces accurately.[2]
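
For what it's worth, here is how that decision usually plays out in code; this is a sketch of the standard textbook trick in Python, not code from the post. Compute d once, combine it with b in the non-cancelling order, and recover the other root from the identity x1 * x2 = c / a, which only involves low-amplification operations:

    import math

    def quadratic_roots(a, b, c):
        # Sketch only: assumes real roots, ignores overflow and a == 0.
        d = math.sqrt(b * b - 4.0 * a * c)
        q = -(b + d) / 2.0 if b >= 0.0 else -(b - d) / 2.0
        return q / a, c / q

    # The naive (-b + d) / (2 a) form cancels badly when b*b >> 4*a*c:
    a, b, c = 1.0, 1e8, 1.0
    naive = (-b + math.sqrt(b * b - 4.0 * a * c)) / (2.0 * a)
    print(naive)                     # roughly -7.45e-9; the true root is about -1e-8
    print(quadratic_roots(a, b, c))  # approximately (-1e8, -1e-8)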

Or, thinking about this bottom-up, I like to imagine that there are lots and lots of ways to get from your expression's inputs to its outputs. Some of those ways have high amplification factors and some have low ones; you basically need to find a path from start to finish where all of the amplification factors along the way are small.

Conclusion

The condition number model is a composable way to think about floating-point error. In this model, you have two basic ways to reduce error: either using higher precision / more accurate operators, which reduces introduced error, or approaching your problem in a different way, so as to end up with different amplification factors. One specific thing you can do is use fused operators like log1p, which, by fusing operations, might avoid an introduced error that is highly amplified. But while that definitely helps in some examples, it's really a question of how much you need to rearrange the expression to use the fused operator and whether or not that rearrangement introduces more error than the fused operator saves.

Footnotes:

[1] Actually, one of the big effects of hypot is to avoid overflow, which condition numbers don't address, so, uhh, it's only part of the story.

[2] Though, again, note that this formalism doesn't think about overflow, which is also a big concern with the quadratic.