Improving Rust with Herbie
Last week I had a fun experience using Herbie to actually improve the numerical accuracy of real code.
Discovering a Rust problem
It all started when Finn (an undergraduate from UW) submitted a PR to FPBench to compile FPBench benchmarks to Rust. Besides being great work and a valuable contribution, he also pointed out that our test framework was failing on his FPBench-to-Rust compiler "due to precision issues with Rust's implementation" of the `asinh` and `acosh` functions.
This intrigued me, because I had assumed that Rust would just link to the system math library. Indeed, it does do that for most functions, but for the inverse hyperbolic functions Rust uses its own custom code.1 [1 Maybe there's a platform support issue?] That implementation looks like this:
```rust
pub fn asinh(self) -> f64 {
    (self.abs() + ((self * self) + 1.0).sqrt()).ln().copysign(self)
}
```
In case you're not Rust-fluent, this basically corresponds to the dictionary definition:2 [2 Plus the right logic for handling signs.]
\[ \operatorname{sinh}^{-1}(x) = \log(x + \sqrt{x^2 + 1}) \]
Yet this is not a very accurate equation. I threw it into Herbie to see what it would say, and got this graph:
Focus on the red line here (the green line is my fix): the formula is very inaccurate for values bigger than about 10^150 or smaller than 1 in absolute value. Let's look at why.
Large values
The issue with large values, bigger than 10^150, is pretty simple. Basically, we're computing `x * x`, and if `x` is bigger than about 10^150, that will overflow and return infinity. Then we add that to one, take a square root, add `x`, and take the logarithm, and the whole time the answer is also infinity. Which isn't close to right! This kind of issue is pretty common, and in a lot of code you might not care about it, but for a library function it's no good. That said, this problem was pretty obvious to me just from looking at the equation, and the fix was simple too.
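To make the failure concrete, here is a small sketch (my own illustration, not Rust's actual code) tracing the naive formula at one large input:

```rust
fn main() {
    let x: f64 = 1e200;
    let sq = x * x;                     // 1e400 is not representable: overflows to +inf
    let inner = (sq + 1.0).sqrt();      // sqrt of infinity is still +inf
    let naive = (x.abs() + inner).ln(); // ln of infinity is +inf
    // The true asinh(1e200) = ln(2e200) ≈ 461.2, nowhere near infinity.
    assert!(naive.is_infinite());
}
```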
The key here is that while `x * x` is indeed too big to represent as a floating-point number, we're going to take a square root a little later, and that square root is representable. So we need to find a better way to do the square root. Luckily, standard math libraries include a function called `hypot`, which computes the square root of the sum of squares without intermediate overflow. So we can just replace:
sqrt(x * x + 1) -> hypot(x, 1)
By the way, Herbie also suggests this fix, since it knows about `hypot`. This fixes the issue with large values.
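A quick sketch of the difference this makes (illustrative, not the actual patch):

```rust
fn main() {
    let x: f64 = 1e200;
    let naive = (x * x + 1.0).sqrt(); // x * x overflows, so this is +inf
    let fixed = x.hypot(1.0);         // ≈ 1e200, no intermediate overflow
    assert!(naive.is_infinite());
    assert!(fixed.is_finite());
    // With hypot, the whole asinh formula stays finite:
    let asinh = (x.abs() + x.hypot(1.0)).ln().copysign(x);
    assert!(asinh.is_finite());
}
```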
Small values
The issue with small values, absolute values less than 1, was less obvious, and I didn't realize it would happen until I looked at Herbie's plot above.
Recall that we're computing `log(x + sqrt(x*x + 1))`; let's trace what happens for small values of `x`. Here `x*x + 1` is approximately 1, as is the square root of that quantity. So now we're looking at `x + 1`, which again is close to 1 since `x` is small. And then we're taking the logarithm of a number close to 1. That's the issue!
What I mean is: logarithm is a function where inputs near 1 produce outputs near 0. But the issue is that its inputs near 1 are spaced pretty far apart (one epsilon, about 1e-16), while its outputs near 0 are spaced much more finely, about 1e-300 apart.3 [3 Ignoring subnormals and some other caveats.] So while we should be able to provide about 1e-300 resolution, we can in fact only provide 1e-16 resolution, which is much worse. This general kind of problem exists any time you're using a function that has a root somewhere other than 0.
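The resolution mismatch is easy to see directly; this sketch probes the spacing of `ln`'s inputs near 1:

```rust
fn main() {
    let eps = f64::EPSILON;               // ~2.2e-16, the gap between 1.0 and the next double
    let next = (1.0f64 + eps).ln();       // smallest nonzero ln() output reachable from inputs near 1
    let lost = (1.0f64 + eps / 4.0).ln(); // 1.0 + eps/4 rounds back to exactly 1.0
    assert!(next > 0.0 && next < 1e-15);
    assert_eq!(lost, 0.0);                // everything below ~eps collapses to zero
}
```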
Here, Herbie offers a less useful solution: a Taylor series expansion. While that works reasonably well for values close to 0, you've got to be clever about picking where the Taylor series should take over, and introducing a branch condition is a bad idea in a fundamental library function, especially one marked `#[inline]`, as this one is: it could mess with vectorization or lead to worse code generation.
So let's look for a better way to solve it. Since this issue with `log` is also somewhat common, libraries provide a helper function, `log1p`, which is `log` but shifted left so the root is at 0. So we can rewrite:
log(...) -> log1p(... - 1)
On its own this doesn't actually fix the problem, but it makes the cancellation that implicitly happens inside logarithm into an explicit cancellation. Now we can work to solve that.
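A sketch of why `log1p` helps once its argument is computed accurately:

```rust
fn main() {
    let x: f64 = 1e-20;
    let via_log = (1.0 + x).ln(); // 1.0 + 1e-20 rounds to 1.0, so this is exactly 0
    let via_log1p = x.ln_1p();    // accurate: ≈ 1e-20
    assert_eq!(via_log, 0.0);
    assert!(via_log1p > 0.0);
}
```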
Here the fix is to move the `-1` to be next to the `sqrt` call, which is what initially produces a value close to 1:

log1p(x + (sqrt(x*x + 1) - 1))
Now, with cancellations like this between square roots (think of 1 as `sqrt(1)`), there is a classic trick that I remember from my work on the quadratic formula: the difference of squares. Basically, when you have `sqrt(a) - sqrt(b)`, multiply it by `sqrt(a) + sqrt(b)` over itself, and use the difference-of-squares formula to simplify the numerator and remove the square roots:
sqrt(x*x + 1) - 1 -> [ (x*x + 1) - 1 ] / [ sqrt(x*x + 1) + 1 ]
Now the numerator can be simplified to `x * x`, which doesn't have a cancellation, while the denominator adds only positive quantities, so it doesn't have any cancellation either:

x * x / (sqrt(x*x + 1) + 1)
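A sketch comparing the cancelling form with the difference-of-squares form at a small input:

```rust
fn main() {
    let x: f64 = 1e-9;
    let cancelling = (x * x + 1.0).sqrt() - 1.0;          // rounds to 1.0 - 1.0 = 0
    let rewritten = x * x / ((x * x + 1.0).sqrt() + 1.0); // ≈ 5e-19, close to the true value
    assert_eq!(cancelling, 0.0);
    assert!(rewritten > 0.0);
}
```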
Now we just need to remember our other problem, overflow. We can remove the `x * x` in the denominator using `hypot`, but we've also now got an `x * x` in the numerator. As before, while `x * x` may overflow in the numerator, the denominator will be similar in size to `x` for large `x`, so we're really computing something like `x * x / x`, which doesn't have to overflow. We just need to avoid the intermediate `x * x`.
One way to fix it is to divide the numerator and denominator by `x`. Then we have:
x * x / (sqrt(x*x + 1) + 1)
  -> x / (sqrt(x*x + 1)/x + 1/x)
  -> x / (sqrt(1 + (1/x)*(1/x)) + 1/x)
  -> x / (hypot(1, 1/x) + 1/x)
Now the numerator is just `x`, while the denominator only involves `1/x` terms, and nothing will overflow.
Another fix that I just came up with while writing this blog post is to zoom out and look at the full term:
log1p(x + x*x / (sqrt(x*x + 1) + 1))
We can factor an `x` out of the argument to `log1p`, thereby again avoiding the cancellation:

log1p(x * (1 + x / (hypot(1, x) + 1)))
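Putting the pieces together, here is a sketch of a full `asinh` along these lines (my own illustration with a made-up name, not the exact code merged into Rust):

```rust
// Sketch of asinh built from the rewrites in this post;
// `asinh_sketch` is not the actual standard library function.
fn asinh_sketch(x: f64) -> f64 {
    let ax = x.abs();
    (ax * (1.0 + ax / (ax.hypot(1.0) + 1.0))).ln_1p().copysign(x)
}

fn main() {
    // Accurate for tiny inputs: asinh(x) ≈ x near 0.
    let tiny = asinh_sketch(1e-300);
    assert!((tiny / 1e-300 - 1.0).abs() < 1e-10);
    // Finite and accurate for huge inputs: asinh(1e200) = ln(2e200) ≈ 461.2.
    let expected = 2.0f64.ln() + 200.0 * 10.0f64.ln();
    assert!((asinh_sketch(1e200) - expected).abs() < 1e-9);
}
```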
I'm not going to go back and contribute this simplification to Rust, but since it replaces a division with a multiplication, it is probably minutely faster on some architectures than the version I contributed.
What does Herbie think?
Of course, any time I do numerics work I think back to: would Herbie have done as well? Here, it did OK but not great.
Noticing both problems went well. If I didn't have Herbie, I would have missed the small-values problem and ended up with a much worse solution.
Fixing the large-values problem also went great. Herbie correctly identified the `hypot`-based solution. It's kind of easy, though.
However, fixing the small-values problem wasn't as great. Herbie suggested a Taylor series, which is both less elegant than my fix (branches, etc.) and less accurate near the branch. I'd like to look into whether Herbie can derive my version, or if not, why not.
Finally, for my "fixed" version, Herbie does a great job showing that the result is accurate and even identifying some subtle remaining issues. For example, let's look at the "final" version above:
log1p(x * (1 + x / (hypot(1, x) + 1)))
Note that we have an intermediate step where we multiply `x` by some value. If `x` is very large, that value will be just under 2, and for super duper large values of `x`, `x * 2` can overflow.
There is a possible fix here. Let's expand out the `log1p` for a second:
log1p(x * (1 + x / (hypot(1, x) + 1)))
  -> log(1 + x * (1 + x / (hypot(1, x) + 1)))
  -> log(x * (1/x + 1 + x / (hypot(1, x) + 1)))
  -> log(x) + log(1 + 1/x + x / (hypot(1, x) + 1))
  -> log(x) + log1p(1/x + x / (hypot(1, x) + 1))
Now we're not multiplying `x` by anything at all! But for very small values of `x`, we'll get cancellation between the `log(x)` and `log1p(1/x)` terms. If we add a branch at 1 we can get the best of both worlds, but there's no way overflow at extremely large `x` is a big enough issue to deserve adding a branch here, and we're also now nearly doubling the cost of `asinh` by computing two different logarithms. So this subtle issue found by Herbie doesn't strike me as worth fixing.
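For the record, a sketch of that trade-off at the extreme end of the range (valid only for very large `x`; as noted, the split form loses accuracy for small `x`):

```rust
fn main() {
    let x: f64 = 1e308;
    // The factored form multiplies x by almost 2, which overflows near the top of the range.
    let factored = x * (1.0 + x / (x.hypot(1.0) + 1.0));
    assert!(factored.is_infinite());
    // Splitting out log(x) avoids the multiply entirely.
    let split = x.ln() + (1.0 / x + x / (x.hypot(1.0) + 1.0)).ln_1p();
    assert!(split.is_finite()); // ≈ 709.9, the true asinh(1e308)
}
```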
Contribution
Anyway, I packaged up my suggestions for `asinh`, as well as somewhat similar fixes to `acosh`, and submitted them to Rust in #104548, and then Max, who by the way is on the job market, implemented, tested, and got it all merged into Rust in #104553. It's now in Rust nightlies and will be released in the next Rust version!