# Improving Rust with Herbie

Last week I had a fun experience using Herbie to actually improve the numerical accuracy of real code.

## Discovering a Rust problem

It all started when Finn (an *undergraduate* from UW) submitted a PR to
FPBench to compile FPBench benchmarks to Rust. Besides being great
work and a valuable contribution, the PR also pointed out that our test
framework was failing on his FPBench-to-Rust compiler "due to
precision issues with Rust's implementation" of the `asinh` and
`acosh` functions.

This intrigued me, because I had assumed that Rust would just link to
the system math library. Indeed, it does do that for most functions,
but for the arc-hyperbolic-trigonometric functions Rust uses its own
custom code.^{1} That implementation looks like this:

```rust
pub fn asinh(self) -> f64 {
    (self.abs() + ((self * self) + 1.0).sqrt()).ln().copysign(self)
}
```

In case you're not Rust-fluent, this basically corresponds to the
dictionary definition:^{2}

\[ \operatorname{sinh}^{-1}(x) = \log(x + \sqrt{x^2 + 1}) \]

Yet this formula is not very accurate when evaluated in floating point. I threw it into Herbie to see what it would say, and got this graph:

Focus on the red line here (the green line is my fix): the formula is
very inaccurate for values bigger than about 10^{150} or smaller than 1
in absolute value. Let's look at why.

## Large values

The issue with large values, bigger than 10^{150}, is pretty simple.
Basically, we're computing `x * x`, and if `x` is bigger than about
10^{150}, that will overflow and return infinity. Then we add that to
one, take a square root, add `x`, and take the logarithm, and the whole
time the answer is also infinity. Which isn't close to right! This
kind of issue is pretty common, and in a lot of code you might not
care about it, but for a library function it's no good. That said,
this problem was pretty obvious to me just from looking at the
equation, and the fix was simple too.
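To see the overflow concretely, here is a minimal demonstration; the formula is the one from Rust's implementation above, just written out as an expression (the `main` wrapper is only for illustration):

```rust
fn main() {
    let x: f64 = 1e200;
    // x * x overflows to infinity...
    assert!((x * x).is_infinite());
    // ...and every later step (add 1, sqrt, add x, log) stays infinite.
    let naive = (x.abs() + ((x * x) + 1.0).sqrt()).ln();
    assert!(naive.is_infinite());
    // Yet the true asinh(1e200) is about ln(2e200) ≈ 461.2,
    // which f64 represents with no trouble at all.
    assert!(x.asinh().is_finite());
}
```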

The key here is that while `x * x` is indeed too big to represent as a
floating-point number, we're going to take a square root a little
later, and that square root *is* representable. So we need to find a
better way to do the square root. Luckily, standard math libraries
include a function called `hypot`, which computes the square root of a
sum of squares without intermediate overflow. So we can just replace:

```
sqrt(x * x + 1) -> hypot(x, 1)
```

By the way, Herbie also suggests this fix, since it knows about `hypot`.
This fixes the issue with large values.
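A quick sketch of the `hypot`-based version (my own free-function rewrite of the method above, not the code that landed in Rust) shows the large-value problem is gone:

```rust
fn main() {
    let x: f64 = 1e200;
    // hypot computes sqrt(x*x + 1) without ever forming x*x,
    // so nothing overflows along the way.
    let fixed = (x.abs() + x.hypot(1.0)).ln().copysign(x);
    assert!(fixed.is_finite());
    // It agrees closely with the standard library's asinh.
    assert!((fixed - x.asinh()).abs() < 1e-10);
}
```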

## Small values

The issue with small values, absolute values less than 1, was less obvious, and I didn't realize it would happen until I looked at Herbie's plot above.

Recall that we're computing `log(x + sqrt(x*x + 1))`; let's trace what
happens for small values of `x`. Here `x*x + 1` is approximately 1, as is
the square root of that quantity. So now we're looking at `x + 1`, which
again is close to 1 since `x` is small. And then we're taking the
logarithm of a number close to 1. That's the issue!

What I mean is: logarithm is a function where inputs near 1 result in
outputs near 0. But its inputs near 1 are spaced pretty far apart (one
epsilon, about 1e-16), while its outputs near 0 are spaced much more
finely, about 1e-300 apart.^{3} So while we *should* be able to
provide about 1e-300 resolution, we can in fact only provide 1e-16
resolution, which is *much* worse. This general kind of problem exists
any time you're using a function that has a root other than 0.
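The lost resolution is easy to demonstrate. For a tiny offset, `1.0 + x` rounds to exactly `1.0`, so `log` can only return `0`, while `ln_1p` (Rust's spelling of `log1p`) takes the offset from 1 directly and keeps its full precision:

```rust
fn main() {
    let x: f64 = 1e-20;
    // 1.0 + 1e-20 rounds to exactly 1.0 in f64, so the log is exactly 0.
    assert_eq!((1.0 + x).ln(), 0.0);
    // ln_1p(1e-20) ≈ 1e-20, which matches the true asinh(1e-20) ≈ 1e-20.
    assert!((x.ln_1p() - x).abs() < 1e-30);
}
```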

Here, Herbie offers a less useful solution: a Taylor series expansion.
While that works reasonably well for values close to 0, you've got to
be clever about picking where the Taylor series should take over.
Also, introducing a branch condition is a bad idea in a fundamental
library, especially one marked `#[inline]`, as this one is: it could
mess with vectorization or lead to worse code generation.

So let's look for a better way to solve it. Since this issue with `log`
is also somewhat common, libraries provide a helper function, `log1p`,
which is `log` but shifted left so the root is at 0. So we can rewrite:

```
log(...) -> log1p(... - 1)
```

On its own this doesn't actually fix the problem, but it makes the
cancellation that *implicitly* happens inside the logarithm into an
explicit cancellation. Now we can work to solve that.

Here the fix is to move the `-1` to be next to the `sqrt` call, which is
what initially produces a value close to 1:

```
log1p(x + (sqrt(x*x + 1) - 1))
```

Now, with cancellations like this between square roots (think of `1` as
`sqrt(1)`), there is a classic trick that I remember from my work on the
quadratic formula: the difference of squares. Basically, when you have
`sqrt(a) - sqrt(b)`, multiply this by `sqrt(a) + sqrt(b)` over itself, and
use the difference of squares formula to simplify the numerator and
remove the square roots:

```
sqrt(x*x + 1) - 1 -> [ (x*x + 1) - 1 ] / [ sqrt(x*x + 1) + 1 ]
```

Now the numerator can be simplified to `x * x`, which doesn't have a
cancellation, while the denominator now adds positive quantities, so
it doesn't have any cancellation either:

```
x * x / (sqrt(x*x + 1) + 1)
```
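A small check (my own illustration) shows the rewrite rescuing exactly the digits the naive form throws away:

```rust
fn main() {
    let x: f64 = 1e-8;
    // Naive form: x*x + 1 = 1 + 1e-16 rounds to exactly 1.0 in f64,
    // so the sqrt gives 1.0 and the subtraction cancels to 0.
    let naive = (x * x + 1.0).sqrt() - 1.0;
    assert_eq!(naive, 0.0);
    // Difference-of-squares form: no cancellation anywhere.
    let rewritten = (x * x) / ((x * x + 1.0).sqrt() + 1.0);
    // The true value is about x*x/2 = 5e-17.
    assert!((rewritten - 5e-17).abs() < 1e-25);
}
```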

Now we just need to remember our other problem, overflow. We can
remove the `x * x` in the denominator using `hypot`, but we've also now
got an `x * x` in the numerator.

As before, while `x * x` may overflow in the numerator, the denominator
will be similar in size to `x` for large `x`, so we're really computing
something like `x * x / x`, which doesn't have to overflow. We just need
to avoid the intermediate `x * x`.

One way to fix it is to divide the numerator and denominator by `x`.
Then we have:

```
x * x / (sqrt(x*x + 1) + 1)
  -> x / (sqrt(x*x + 1)/x + 1/x)
  -> x / (sqrt(1 + (1/x)*(1/x)) + 1/x)
  -> x / (hypot(1, 1/x) + 1/x)
```

Now the numerator is approximately size `x` while the denominator is
size `1/x`, and nothing will overflow.

Another fix that I just came up with while writing this blog post is to zoom out and look at the full term:

```
log1p(x + x*x / (sqrt(x*x + 1) + 1))
```

We can factor an `x` out from the argument to `log1p`, thereby again
avoiding the cancellation:

```
log1p(x * (1 + x / (hypot(1, x) + 1)))
```

I'm not going to go back and contribute this simplification to Rust, but since it replaces a division by a multiply it is probably minutely faster than the version I contributed on some architectures.
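Putting the pieces together, a minimal sketch of this factored version in Rust (the function name and structure are mine, not the code that actually landed in Rust) could look like:

```rust
/// Sketch of asinh(x) = log1p(|x| * (1 + |x| / (hypot(1, |x|) + 1))),
/// with the sign restored at the end.
fn asinh_sketch(x: f64) -> f64 {
    let ax = x.abs();
    (ax * (1.0 + ax / (ax.hypot(1.0) + 1.0))).ln_1p().copysign(x)
}

fn main() {
    // Accurate for tiny inputs, where asinh(x) ≈ x...
    assert!((asinh_sketch(1e-20) - 1e-20).abs() < 1e-30);
    // ...finite for huge inputs, where the naive formula overflows...
    assert!(asinh_sketch(1e200).is_finite());
    // ...and it matches the standard library on ordinary inputs.
    assert!((asinh_sketch(2.0) - 2.0f64.asinh()).abs() < 1e-15);
}
```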

## What does Herbie think?

Of course, any time I do any numerics work, I think back to "would Herbie do as well?" Here, it did OK but not great.

Noticing both problems went well. If I didn't have Herbie, I would have missed the small-values problem and ended up with a much worse solution.

Fixing the large-values problem also went great. Herbie correctly identified the `hypot`-based solution. It's kind of easy, though.

However, fixing the small-values problem wasn't as great. Herbie suggested a Taylor series, which is not only less elegant than my fix (branches, etc.) but also less accurate near the branch. I'd like to look into whether Herbie can derive my version, or if not, why not.

Finally, for my "fixed" version, Herbie does a great job showing that the result is accurate and even identifying some subtle remaining issues. For example, let's look at the "final" version above:

```
log1p(x * (1 + x / (hypot(1, x) + 1)))
```

Note that we have an intermediate step where we multiply `x` by some
value. If `x` is very large, that value will be just under `2`, and for
*super duper* large values of `x`, `x * 2` can overflow.

There *is* a possible fix here. Let's expand out the `log1p` for a second:

```
log1p(x * (1 + x / (hypot(1, x) + 1)))
  -> log(1 + x * (1 + x / (hypot(1, x) + 1)))
  -> log(x * (1/x + 1 + x / (hypot(1, x) + 1)))
  -> log(x) + log(1 + 1/x + x / (hypot(1, x) + 1))
  -> log(x) + log1p(1/x + x / (hypot(1, x) + 1))
```

Now we're not multiplying `x` by anything at all! But for very small
values of `x`, we'll get cancellation between the `log(x)` and `log1p(1/x)`
terms. If we add a branch at 1 we can get the best of both worlds, but
there's no way overflow at extremely large `x` is a big enough issue to
deserve adding a branch here, and we're also now nearly doubling the
cost of `asinh` by computing two different logarithms. So this subtle
issue found by Herbie doesn't strike me as worth fixing.

## Contribution

Anyway, I packaged up my suggestions to `asinh`, as well as somewhat
similar fixes to `acosh`, and submitted them to Rust in #104548, and
then Max, who by the way is on the job market, implemented, tested,
and got it all merged into Rust in #104553. It's now in Rust nightlies
and will be released in the next Rust version!

## Footnotes:

^{1} Maybe there's a platform support issue?

^{2} Plus the right logic for handling signs.

^{3} Ignoring subnormals and some other caveats.