# Improving Rust with Herbie

*Pavel Panchekha, 28 November 2022. Shared under CC-BY-SA.*

Last week I had a fun experience using Herbie to actually improve the numerical accuracy of real code.

## Discovering a Rust problem

It all started when Finn (an undergraduate from UW) submitted a PR to FPBench to compile FPBench benchmarks to Rust. Besides the PR being great work and a valuable contribution, Finn also pointed out that our test framework was failing on his FPBench-to-Rust compiler "due to precision issues with Rust's implementation" of the asinh and acosh functions.

This intrigued me, because I had assumed that Rust would just link to the system math library. Indeed, it does do that for most functions, but for the arc-hyperbolic-trigonometric functions Rust uses its own custom code [1: Maybe there's a platform support issue?]. That implementation looks like this:

```rust
pub fn asinh(self) -> f64 {
    (self.abs() + ((self * self) + 1.0).sqrt()).ln().copysign(self)
}
```


In case you're not Rust-fluent, this basically corresponds to the dictionary definition [2: Plus the right logic for handling signs.]:

$\operatorname{sinh}^{-1}(x) = \log(x + \sqrt{x^2 + 1})$

Yet this is not a very accurate formula. I threw it into Herbie to see what it would say, and got a graph of its accuracy. Focus on the red line (the green line is my fix): the formula is very inaccurate for values bigger than about $10^{150}$ or smaller than 1 in absolute value. Let's look at why.

## Large values

The issue with large values, bigger than $10^{150}$, is pretty simple. Basically, we're computing x * x, and if x is bigger than about $10^{150}$, that will overflow and return infinity. Then we add that to one, take a square root, add x, and take the logarithm, and the whole time the answer stays infinity. Which isn't close to right! This kind of issue is pretty common, and in a lot of code you might not care about it, but for a library function it's no good. That said, this problem was pretty obvious to me just from looking at the equation, and the fix was simple too.
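
To see the overflow concretely, here's a sketch of the same formula written as a free function (a hypothetical `asinh_naive`, not Rust's actual code):

```rust
// The naive formula: x * x overflows to infinity once |x| exceeds
// roughly 1.3e154 (the square root of f64::MAX), and that infinity
// then propagates through sqrt, +, and ln.
fn asinh_naive(x: f64) -> f64 {
    (x.abs() + (x * x + 1.0).sqrt()).ln().copysign(x)
}

fn main() {
    let x = 1e200_f64;
    // The true answer is ln(2x) ≈ 461.2, which is perfectly
    // representable, but the naive formula returns infinity.
    println!("naive: {}", asinh_naive(x));
}
```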

The key here is that while x * x is indeed too big to represent as a floating-point number, we're going to take a square root a little later, and that square root is representable. So we need to find a better way to do the square root. Luckily, standard math libraries include a function called hypot, which computes the square root of a sum of squares without intermediate overflow. So we can just replace:

```
sqrt(x * x + 1) -> hypot(x, 1)
```


By the way, Herbie also suggests this fix, since it knows about hypot. This fixes the issue with large values.
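
Rust's f64 has a hypot method, so this is a one-line change. As a sketch (a hypothetical free function, not Rust's actual code):

```rust
// Same formula, but sqrt(x*x + 1) replaced by hypot(x, 1), which
// computes sqrt(x*x + 1*1) without forming the intermediate x*x.
fn asinh_large_ok(x: f64) -> f64 {
    (x.abs() + x.hypot(1.0)).ln().copysign(x)
}

fn main() {
    let x = 1e200_f64;
    // Now we get the finite, correct answer: ln(2e200) ≈ 461.21
    println!("{}", asinh_large_ok(x));
}
```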

## Small values

The issue with small values, absolute values less than 1, was less obvious, and I didn't realize it would happen until I looked at Herbie's plot above.

Recall that we're computing log(x + sqrt(x*x + 1)); let's trace what happens for small values of x. Here x*x + 1 is approximately 1, as is the square root of that quantity. So now we're looking at x + 1, which again is close to 1 since x is small. And then we're taking the logarithm of a number close to 1. That's the issue!

What I mean is: logarithm is a function where inputs near 1 result in outputs near 0. But the issue is that its inputs near 1 are spaced pretty far apart (one epsilon, about 1e-16) while its outputs near 0 are spaced much more finely, about 1e-300 apart [3: Ignoring subnormals and some other caveats]. So while we should be able to provide about 1e-300 resolution, we can in fact only provide 1e-16 resolution, which is much worse. This general kind of problem exists any time you're using a function that has a root somewhere other than 0.
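
You can see the resolution loss directly: for a small enough x, the sum 1.0 + x rounds away everything below one epsilon before the logarithm even runs.

```rust
fn main() {
    let x = 1e-20_f64;
    // 1.0 + x rounds to exactly 1.0, so the naive form returns exactly 0...
    println!("{}", (1.0 + x).ln());
    // ...while ln_1p (Rust's name for log1p) keeps full accuracy:
    // ln(1 + x) ≈ x for tiny x.
    println!("{}", x.ln_1p());
}
```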

Here, Herbie offers a less useful solution: a Taylor series expansion. While that works reasonably well for values close to 0, you've got to be clever about picking where the Taylor series should take over. Besides, introducing a branch is a bad idea in a fundamental library function, especially one marked #[inline], as this one is: it could mess with vectorization or lead to worse code generation.

So let's look for a better way to solve it. Since this issue with log is also somewhat common, libraries provide a helper function, log1p, which is log but shifted left so the root is at 0. So we can rewrite:

```
log(...) -> log1p(... - 1)
```


On its own this doesn't actually fix the problem, but it makes the cancellation that implicitly happens inside logarithm into an explicit cancellation. Now we can work to solve that.

Here the fix is to move the -1 to be next to the sqrt call, which is what initially produces a value close to 1:

```
log1p(x + (sqrt(x*x + 1) - 1))
```


Now, with cancellations like this between square roots (think of 1 as sqrt(1)), there is a classic trick that I remember from my work on the quadratic formula: the difference of squares. Basically, when you have sqrt(a) - sqrt(b), multiply it by sqrt(a) + sqrt(b) over itself, and use the difference-of-squares formula to simplify the numerator and remove the square roots:

```
sqrt(x*x + 1) - 1 -> [ (x*x + 1) - 1 ] / [ sqrt(x*x + 1) + 1 ]
```


Now the numerator simplifies to x * x, which has no cancellation, while the denominator adds two positive quantities, so it has no cancellation either:

```
x * x / (sqrt(x*x + 1) + 1)
```
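
A quick check of this rewrite, with x small enough that x*x rounds away entirely in the sum x*x + 1:

```rust
// Naive and rewritten forms of sqrt(x*x + 1) - 1.
fn naive(x: f64) -> f64 {
    (x * x + 1.0).sqrt() - 1.0
}
fn rewritten(x: f64) -> f64 {
    x * x / ((x * x + 1.0).sqrt() + 1.0)
}

fn main() {
    let x = 1e-9_f64;
    // 1 + 1e-18 rounds to exactly 1, so the naive form cancels to 0,
    // while the rewrite returns roughly the true value x*x/2 = 5e-19.
    println!("naive = {}, rewritten = {}", naive(x), rewritten(x));
}
```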


Now we just need to remember our other problem, overflow. We can remove the x * x in the denominator using hypot, but we've also now got an x * x in the numerator.

As before, while x * x may overflow in the numerator, the denominator will be similar in size to x for large x, so we're really computing something like x * x / x, which doesn't have to overflow. We just need to avoid the intermediate x * x.

One way to fix it is to divide the numerator and denominator by x. Then we have:

```
x * x / (sqrt(x*x + 1) + 1)
-> x / (sqrt(x*x + 1)/x + 1/x)
-> x / (sqrt(1 + (1/x)*(1/x)) + 1/x)
-> x / (hypot(1, 1/x) + 1/x)
```


Now the only quantities involved are about the size of x and 1/x, never x * x, so nothing will overflow.
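
Putting both fixes together, the whole function looks something like this sketch (hedged: a plausible rendering of the idea, not necessarily character-for-character what was contributed):

```rust
// Sketch of asinh with both fixes: hypot avoids overflow in the square
// root, ln_1p avoids the log-near-1 resolution loss, and dividing
// through by |x| keeps x * x out of the numerator.
fn asinh_fixed(x: f64) -> f64 {
    let ax = x.abs();
    let ix = 1.0 / ax;
    (ax + ax / (1.0_f64.hypot(ix) + ix)).ln_1p().copysign(x)
}

fn main() {
    for &x in &[1e-9_f64, 1.0, 2.5, 1e200] {
        println!("{}: {} vs std {}", x, asinh_fixed(x), x.asinh());
    }
}
```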

Another fix that I just came up with while writing this blog post is to zoom out and look at the full term:

```
log1p(x + x*x / (sqrt(x*x + 1) + 1))
```


We can factor an x out from the argument to log1p, thereby again avoiding the cancellation:

```
log1p(x * (1 + x / (hypot(1, x) + 1)))
```
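
As a sketch, that multiplication-based variant in Rust (again a hypothetical free function):

```rust
// Multiplication-based variant: factor |x| out of the log1p argument,
// leaving one division and one multiply.
fn asinh_mul(x: f64) -> f64 {
    let ax = x.abs();
    (ax * (1.0 + ax / (ax.hypot(1.0) + 1.0))).ln_1p().copysign(x)
}

fn main() {
    println!("{} vs std {}", asinh_mul(2.5), 2.5_f64.asinh());
}
```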


I'm not going to go back and contribute this simplification to Rust, but since it replaces a division by a multiply it is probably minutely faster than the version I contributed on some architectures.

## What does Herbie think?

Of course, any time I do any numerics work, I think back to "would Herbie do as well?" Here, it did OK but not great.

Noticing both problems went well. If I didn't have Herbie, I would have missed the small-values problem and ended up with a much worse solution.

Fixing the large-values problem also went great. Herbie correctly identified the hypot-based solution. It's kind of easy, though.

However, fixing the small-values problem didn't go as well. Herbie suggested a Taylor series, which is not only less elegant than my fix (branches, etc.) but also less accurate near the branch. I'd like to look into whether Herbie can derive my version, or if not, why not.

Finally, for my "fixed" version, Herbie does a great job showing that the result is accurate and even identifying some subtle remaining issues. For example, let's look at the "final" version above:

```
log1p(x * (1 + x / (hypot(1, x) + 1)))
```


Note that we have an intermediate step where we multiply x by some value. If x is very large, that value will be just under 2, and for super duper large values of x, x * 2 can overflow.
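
A quick demonstration of that corner case, isolating the log1p argument as a hypothetical helper:

```rust
// The argument passed to log1p in the multiplication-based variant.
fn mul_arg(x: f64) -> f64 {
    x * (1.0 + x / (x.hypot(1.0) + 1.0))
}

fn main() {
    let x = 1e308_f64;
    // The multiplier approaches 2 for large x, and 2e308 exceeds
    // f64::MAX (~1.8e308), so the product overflows to infinity...
    let arg = mul_arg(x);
    println!("{}", arg);
    // ...even though the true asinh(1e308) = ln(2e308) ≈ 709.9 is finite.
    println!("{}", arg.ln_1p());
}
```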

There is a possible fix here. Let's expand out the log1p for a second:

```
log1p(x * (1 + x / (hypot(1, x) + 1)))
-> log(1 + x * (1 + x / (hypot(1, x) + 1)))
-> log(x * (1/x + 1 + x / (hypot(1, x) + 1)))
-> log(x) + log(1 + 1/x + x / (hypot(1, x) + 1))
-> log(x) + log1p(1/x + x / (hypot(1, x) + 1))
```


Now we're not multiplying x by anything at all! But for very small values of x, we'll get cancellation between the log(x) and log1p(1/x) terms. If we add a branch at 1 we can get the best of both worlds, but there's no way overflow at extremely large x is a big enough issue to deserve adding a branch here, and we're also now nearly doubling the cost of asinh by computing two different logarithms. So this subtle issue found by Herbie doesn't strike me as worth fixing.

## Contribution

Anyway, I packaged up my suggestions to asinh, as well as somewhat similar fixes to acosh, and submitted them to Rust in #104548, and then Max, who by the way is on the job market, implemented, tested, and got it all merged into Rust in #104553. It's now in Rust nightlies and will be released in the next Rust version!
