# Batches for Recursive Data Structures

Most of Herbie's runtime is spent thinking about expressions: generating them, evaluating them, mutating them, analyzing them, filtering them, and so on. So naturally this has to be as fast as possible.

Herbie is a primarily *mutational* synthesis engine—most expressions
it thinks about are generated by changing, in small ways, another
expression—so most of its expressions are very similar and share
subexpressions. This makes duplication a big problem: Herbie wastes a
lot of time analyzing identical subexpressions repeatedly. Caching
things can help, but caches introduce their own problems: hashing
large expressions is slow, hash tables are slow, garbage collection
becomes an issue, and memory usage increases. We are therefore
rewriting Herbie to focus on *batches*.

A batch is a linear encoding of a recursive data structure. This is a well-known performance trick (Adrian Sampson has a good walkthrough), but the details of actually working with this representation have been challenging, and in this blog post I want to walk through some of the things I've learned.

## The basics of batches

A batch stores an expression by storing each node in a vector, with
back-references to other nodes. An expression like `x * tan(x)` ends
up stored as a batch of three elements:

```
%0 = x
%1 = tan(%0)
%2 = %0 * %1
```

Batches never contain duplicates (they are hashconsed), so for expressions with a lot of shared structure they can save a lot of memory.

In Herbie, which is written in Racket, the precise representation is a vector of nodes, where each node is a variable, a numeric constant, or a list where the first element is an operator name and the rest are indexes into the vector for the backreferences:

```racket
#(x (tan 0) (* 0 1))
```

To convert an expression into this form, you can recursively traverse the expression, adding every node and returning its index in the vector. As you do that, you can deduplicate by storing a hash table mapping nodes to indices. In Herbie, we have "mutable batches" which are a hash table plus a list, and "immutable batches" which store just the vector.
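To make the mutable-batch idea concrete outside Racket, here's a hypothetical Python sketch (not Herbie's actual code): nodes are atoms or tuples of an operator plus backreference indices, and a hash table deduplicates as nodes are added.

```python
# Hypothetical sketch of a mutable batch: a list of nodes plus a hash
# table mapping each node to its index, used for deduplication.

def add_node(nodes, index, node):
    """Add a node to the batch, deduplicating via the hash table."""
    if node not in index:
        index[node] = len(nodes)
        nodes.append(node)
    return index[node]

def expr_to_batch(expr, nodes, index):
    """Recursively flatten an expression into the batch; return its index."""
    if isinstance(expr, tuple):          # operator application
        op, *args = expr
        refs = tuple(expr_to_batch(a, nodes, index) for a in args)
        return add_node(nodes, index, (op,) + refs)
    return add_node(nodes, index, expr)  # variable or numeric constant

nodes, index = [], {}
root = expr_to_batch(("*", "x", ("tan", "x")), nodes, index)
# nodes is now ['x', ('tan', 0), ('*', 0, 1)] and root is 2
```

Because `add_node` consults the hash table first, the shared `x` subexpression is stored once and both its uses point at index 0.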

## Computing over batches

Ordinary programming with batches follows a common pattern. For example, suppose you want to convert from a batch to an expression. There's a bad way and a good way. The bad way is recursive:

```racket
(define (batch-get batch idx)
  (match (vector-ref batch idx)
    [(list op refs ...) (list* op (map (curry batch-get batch) refs))]
    [atom atom]))
```

The reason this is bad is that if the expression has a lot of shared subexpressions, you end up traversing those subexpressions a lot of times, and you end up allocating memory for each copy. In the worst case, this is exponential in both space and time. The good way is iterative:

```racket
(define (batch-get batch idx)
  (define batch-exprs (make-vector (vector-length batch)))
  (for ([node (in-vector batch)] [i (in-naturals)])
    (define expr
      (match node
        [(list op refs ...) (list* op (map (curry vector-ref batch-exprs) refs))]
        [atom atom]))
    (vector-set! batch-exprs i expr))
  (vector-ref batch-exprs idx))
```

It's way more code, but it uses a common pattern with batches: instead
of writing a recursive traversal, we allocate a temporary vector that
will store all the recursive outputs, and then fill it with a simple
`for` loop over batch nodes. Note that as we fill the `batch-exprs`
vector, we also reference previous elements using the back-references
in the input batch.

Here's another example. Suppose we want to compute the depth of an expression, but stored as a batch instead of as a normal expression tree. You can actually compute depth bottom-up (the depth of an expression is one plus the depth of the deepest subexpression) or top-down (the depth of the root is 0, the depth of a subexpression is one more than the depth of its parent). When you recurse, there's a subtle difference, because one of these options is tail-recursive and one isn't. But we can do either one with batches.

Here's bottom-up as a recursive function:

```racket
(define (depth expr)
  (match expr
    [(list op args ...)
     (define child-depths (map depth args))
     (+ 1 (apply max child-depths))]
    [other 0]))
```

Here's the same thing on a batch:

```racket
(define (batch-depth batch)
  (define depths (make-vector (vector-length batch)))
  (for ([node (in-vector batch)] [idx (in-naturals)])
    (vector-set! depths idx
                 (match node
                   [(list op args ...)
                    (define child-depths (map (curry vector-ref depths) args))
                    (+ 1 (apply max child-depths))]
                   [other 0])))
  (vector-ref depths (- (vector-length batch) 1)))
```

It's the same body, basically, except we allocate a vector of outputs and instead of making recursive calls we look in that output vector.

Top-down depth is a little trickier even recursively: we need to
update a mutable reference (a `box` in Racket) every time we reach a
leaf node:

```racket
(define (depth node output [acc 0])
  (match node
    [(list op args ...)
     (for ([arg (in-list args)])
       (depth arg output (+ acc 1)))]
    [other (set-box! output (max (unbox output) acc))]))
```

The caller has to allocate the output cell and read the answer from
it. The default `acc` value of 0 gives the depth of the root of the
tree. Note that I'm assuming operators have at least one argument; this
is an example, not necessarily live code. Here's the same thing but
with a batch:

```racket
(define (batch-depth batch output)
  (define accs (make-vector (vector-length batch) 0))
  (define last-idx (- (vector-length batch) 1))
  (for ([idx (in-range last-idx -1 -1)])
    (define node (vector-ref batch idx))
    (define acc (vector-ref accs idx))
    (match node
      [(list op args ...)
       (for ([arg (in-list args)])
         (vector-set! accs arg (max (vector-ref accs arg) (+ acc 1))))]
      [other (set-box! output (max (unbox output) acc))])))
```

This *also* looks like the recursive version, but this time the
temporary vector tracks *inputs* to the recursive call instead of
outputs from it. Also, there's one tricky difference: the recursive
version can go into the same subtree twice with different `acc`
arguments. The batch version can't, so the batch version computes the
`max` of all recursive calls to the same node. So for top-down
traversals, a little bit of extra logic is needed to convert from
expressions to batches. Otherwise, though, there's again a lot of
shared code.

More generally, any bottom-up tree traversal becomes a loop through the nodes in order with a side table of recursive outputs, while any top-down tree traversal becomes a loop through the nodes in reverse order with a side table of recursive inputs. A recursive algorithm on expressions can be broken down into a sequence of bottom-up and top-down traversals in order to do it on batches.

Here's one realistic example: A batch can contain "zombie" nodes not referenced by the root node. (This can happen due to transformations, for example.) You can remove them through two traversals. First a top-down traversal marks all nodes reachable from the root. Then a bottom-up traversal copies the reachable nodes to a new vector. (The top-down depth code above, by the way, isn't correct in the presence of zombie nodes.)
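The two traversals can be sketched in Python (again hypothetical, not Herbie's code): a backward pass marks live nodes, a forward pass copies them while reindexing.

```python
# Hypothetical sketch of zombie culling: a top-down pass marks the nodes
# reachable from the roots, then a bottom-up pass copies live nodes into
# a fresh batch, translating backreferences as it goes.

def cull(nodes, roots):
    live = [False] * len(nodes)
    for r in roots:
        live[r] = True
    # Top-down: walk backwards, so every parent is visited before its
    # children (backreferences always point to lower indices).
    for idx in range(len(nodes) - 1, -1, -1):
        node = nodes[idx]
        if live[idx] and isinstance(node, tuple):
            for ref in node[1:]:
                live[ref] = True
    # Bottom-up: copy live nodes forward, building a reindexing table.
    out, reindex = [], [None] * len(nodes)
    for idx, node in enumerate(nodes):
        if not live[idx]:
            continue
        if isinstance(node, tuple):
            node = (node[0],) + tuple(reindex[r] for r in node[1:])
        reindex[idx] = len(out)
        out.append(node)
    return out, [reindex[r] for r in roots]
```

For example, `cull(['y', 'x', ('tan', 1), ('*', 1, 2)], [3])` drops the zombie `'y'` and shifts every surviving index down by one.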

This pattern of turning recursion into iteration means you might do
*more* calls than you need to, but you'll never do *duplicate* calls. This
is typically a win. It might also lead to less allocation, though I
think in a language like Racket that impact is hard to predict. In
something like C or Rust, where the nodes can be stored directly in
the vector, the allocation impact will be real and large, and you'll
also get potential benefits from things like cache locality and
prefetching.

## Performing rewrites

Recursive functions that return recursive data structures are a little trickier.

For example, suppose you want to write a function that finds and
replaces all `(+ a (+ b c))` with `(+ (+ a b) c)`. Since the input is
expressions and the output is expressions, this is a function from
batch to batch. Since it is bottom-up, it's implemented with a `for`
loop forward over the input batch.

Note that the cardinality of the input and output batches might
differ. That's true even for this specific associativity rewrite,
which replaces two nodes with two other nodes: the rewrite can cause
there to be more or less sharing. In `(* (+ 1 2) (+ 1 (+ 2 3)))`, the
original expression's 3 addition nodes are all different, but after
rewriting the subexpression `(+ 1 2)` appears twice.

So to implement this rewrite, you need to have three vectors: the
input batch; the output batch; and a *reindexing* vector that goes from
the index of the input node in the input batch to the index of the
rewritten node in the output batch.

It looks something like this, where `b1` is the input (immutable) batch,
`b2` is the output (mutable) batch, and `reindex` is the reindexing
vector. The `batch-add!` method adds a node to a mutable batch and
returns its index; it'll deduplicate the node when adding, so the index
might not be the end of the batch.
```racket
(for ([node (in-vector b1)] [idx (in-naturals)])
  (match node
    [(list '+ a rhs)
     (match (vector-ref b1 rhs)
       [(list '+ b c)
        (define idx-a (vector-ref reindex a))
        (define idx-b (vector-ref reindex b))
        (define idx-c (vector-ref reindex c))
        (define idx-a+b (batch-add! b2 (list '+ idx-a idx-b)))
        (define idx-lhs+c (batch-add! b2 (list '+ idx-a+b idx-c)))
        (vector-set! reindex idx idx-lhs+c)])]))
```

Note that in the pattern matching, you end up needing to index back
into the input batch to figure out what the backreferences mean. Then,
you need to translate the remaining backreferences using the
reindexing vector before adding the rewritten nodes into the output
batch. Also, I'm not handling the `else` case, which ends up
consuming many more lines of code.

The upshot is that doing this kind of rewrite is possible, but also kind of a pain. And dangerous! If you forget to reindex one of the backreferences you end up with garbage, and this is a very hard bug to debug.
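To show the whole thing end to end, here's a hypothetical Python sketch of the associativity rewrite over a full batch, including the else cases (assumes `+` nodes are binary; `add_node` is a stand-in for `batch-add!`):

```python
# Hypothetical sketch: rewrite (+ a (+ b c)) to (+ (+ a b) c) over a
# batch, maintaining a reindexing vector from input to output indices.

def add_node(nodes, index, node):
    """Deduplicating add, like batch-add! on a mutable batch."""
    if node not in index:
        index[node] = len(nodes)
        nodes.append(node)
    return index[node]

def rewrite_assoc(b1):
    b2, index = [], {}
    reindex = [None] * len(b1)
    for idx, node in enumerate(b1):
        if (isinstance(node, tuple) and node[0] == "+"
                and isinstance(b1[node[2]], tuple) and b1[node[2]][0] == "+"):
            a, (_, b, c) = node[1], b1[node[2]]
            # Translate backreferences into the output index space.
            ab = add_node(b2, index, ("+", reindex[a], reindex[b]))
            reindex[idx] = add_node(b2, index, ("+", ab, reindex[c]))
        elif isinstance(node, tuple):    # other operators: just reindex
            op, *refs = node
            reindex[idx] = add_node(b2, index,
                                    (op,) + tuple(reindex[r] for r in refs))
        else:                            # atoms copy over unchanged
            reindex[idx] = add_node(b2, index, node)
    return b2, reindex
```

Running it on the batch for `(* (+ 1 2) (+ 1 (+ 2 3)))`, the rewritten `(+ (+ 1 2) 3)` node deduplicates against the existing `(+ 1 2)` node, and the old `(+ 2 3)` node becomes a zombie.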

## Batchrefs

One solution to the danger here is types, or, since this is Racket, dynamic type checks.

The key abstraction is the "batchref", which is an index into a batch,
but wrapped in a way that keeps track of the batch. You can `deref` a
batchref to index into the batch and pull out "one level" of the node:

```racket
(struct batch-ref (batch idx))

(define (deref bref)
  (define node (vector-ref (batch-ref-batch bref) (batch-ref-idx bref)))
  (match node
    [(list op args ...)
     (list* op (map (curry batch-ref (batch-ref-batch bref)) args))]
    [other other]))
```

Note that after we pull out the node, we replace *its* back-references
with `batch-ref` objects, to continue to enforce safety.

Then the pattern matching can be folded away like this:

```racket
(match (deref (batch-ref b1 idx))
  [(list '+ a (app deref (list '+ b c))) ...]
  [(list op args ...)
   (batch-add! b2 (list* op (map (λ (arg) (vector-ref reindex (batch-ref-idx arg)))
                                 args)))]
  [other (batch-add! b2 other)])
```

Here I've used `deref` together with a neat Racket feature to make the
pattern matching nicer. First, I construct a `batch-ref` and `deref` it
before entering the `match`; this replaces all my back-references with
`batch-ref` objects.

Then, in the pattern match, I match `(app deref ...)`, a Racket `app`
pattern. This matches a `batch-ref`, calls `deref` on it, and then matches
on the output; this lets me turn the nested `match` calls into a flat
one, which also means I only need to write the `else` case once. `app`
patterns are a nice feature of Racket, kind of analogous to "view
patterns" in Haskell or extractors in Scala.

In the main pattern-matching case, `a`, `b`, and `c` are now batchrefs. The
next idea is just to construct the replacement expression out of these
batchrefs:

```racket
(match (deref (batch-ref b1 idx))
  [(list '+ a (app deref (list '+ b c)))
   (define new-expr (list '+ (list '+ a b) c))
   ...]
  ...)
```

The `new-expr` is an expression tree where the leaves are references
into the input batch. We can then recursively add every node in that
expression, translating any batchrefs from old to new batch using the
reindexing table:

```racket
(define (batch-add-transformed! b2 reindex expr)
  (match expr
    [(batch-ref b1 idx) (vector-ref reindex idx)]
    [(list op args ...)
     (batch-add! b2 (list* op (map (curry batch-add-transformed! b2 reindex)
                                   args)))]))
```

Then the actual transformation looks like this:

```racket
(match (deref (batch-ref b1 idx))
  [(list '+ a (app deref (list '+ b c)))
   (define new-idx
     (batch-add-transformed! b2 reindex (list '+ (list '+ a b) c)))
   (vector-set! reindex idx new-idx)]
  ...)
```

You can wrap this whole pattern into a helper function, so that the final invocation looks like this:

```racket
(batch-replace b1
  (lambda (node)
    (match node
      [`(+ ,a ,(app deref `(+ ,b ,c))) `(+ (+ ,a ,b) ,c)]
      [other other])))
```

The `batch-replace` helper handles all of the allocation, reindexing,
and all of that. The `node` passed into the callback uses `batch-ref` for
all the back-references, hence the `deref`, and it also handles
constructing the new batch and replacing all the indices. In Racket
you can even write some custom syntax around this—we're still
playing with what exactly that syntax should be and haven't yet
settled on anything.

## Batches of multiple expressions

A batch can also store multiple expressions. In Herbie this is extremely important because all of our expressions are quite similar. In one key part of Herbie we already do this, and the batch is 25 times smaller than the expressions themselves.

A batch storing multiple expressions looks the same as a normal batch.
But we now also want to keep track of where each initial expression is
in the batch; we call that second vector the "rootvec". To go back
from the batch to the expressions, you use the iterative algorithm
above, and at the end read one expression for each entry of the
`rootvec`. It's pretty fast.
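As a hypothetical Python sketch (not Herbie's code), decoding a multi-expression batch is one forward pass plus a lookup per root:

```python
# Hypothetical sketch: decode a batch holding multiple expressions. One
# forward pass rebuilds every node's expression; the rootvec then picks
# out each stored expression.

def batch_to_exprs(nodes, rootvec):
    exprs = [None] * len(nodes)
    for idx, node in enumerate(nodes):
        if isinstance(node, tuple):
            op, *refs = node
            # Backreferences point backwards, so exprs[r] is already filled.
            exprs[idx] = (op,) + tuple(exprs[r] for r in refs)
        else:
            exprs[idx] = node
    return [exprs[r] for r in rootvec]
```

For `['x', ('tan', 0), ('*', 0, 1)]` with rootvec `[1, 2]`, this recovers both `tan(x)` and `x * tan(x)` in one pass.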

As soon as you have multiple expressions, it becomes even easier to have nodes not referenced from one of the "root nodes". We call these "zombie" nodes and you have to "cull" them every now and then. If you think about it, the batch is kind of like an arena allocator for expressions, and zombie culling is just manual garbage collection. The same top-down-then-bottom-up pattern described earlier works. It's basically a mark-copy collector.

There's a bit of a trade-off here—if you have too many "zombie" nodes, you can waste time analyzing or rewriting those nodes, but if you "cull" the zombie nodes too often that also takes a lot of time. This is the normal garbage collection frequency tradeoff. So sometimes you make judgment calls about whether or not to delete those zombie nodes.

For example, Herbie's series expander rewrites `pow(x, 1/2)` into
`sqrt(x)`, because the series expander thinks about real semantics and
it has custom `sqrt` code that's faster than its generic `pow` code. The
rewrite itself is done just like above, using `batch-replace`. But it
can leave a zombie `1/2` node in the batch, if the `pow` was the only
expression that used `1/2`. We could cull the zombies afterwards—but
series-expanding a constant is very cheap, so it's better to just
leave the zombie node around.

## A type perspective

One big problem working with batches is that you end up with lots of vectors. A batch of multiple expressions is already two vectors (the batch and the rootvec). A rewrite involves a reindexing vector and two different batches. Then you map over the batches to produce new vectors. Pretty quickly you have a dozen vectors and it's hard to remember which vector is parallel to which other vector, and which indices need to be transformed how. It's a mess.

I find it helpful to think about this from a type perspective. Imagine
a vector not as a type `vector<T>` but as a type `vector<α, T>`, where `α`
represents the "index space". You can think of it as the vector
length, but I prefer thinking of it as a phantom type, totally
meaningless on its own. In practice, we represent the index space as a
reference to a batch. Each batch is its own index space, and then if
you compute some value for each node in a batch, that has the same
index space as the batch.

If you have two vectors with the same length, like in the depth code above, then both vectors have the same index space:

```
batch  : vector<α, Node>
depths : vector<α, Nat>
```

When two vectors have the same index space you can traverse them in parallel.

Alternatively, if you have two vectors of different lengths, you think of them as having different index spaces:

```
input-batch  : vector<α, Node>
output-batch : vector<β, Node>
```

When you have a reindexing vector, it maps from one indexing space to another:

```
reindex : vector<α, β>
```

A `batch-ref` is a batch plus an index, which we think of as a member of
the index space. So:

```
batch-ref<α> := vector<α, Node> × α
deref : vector<α, Node> × α → Node
```

That said, it's sometimes better to identify `batch-ref<α>` with `α`
itself, that is, to think about a `batch-ref` as a kind of safe index
type. I already mentioned that sometimes you want to represent the
type `α` at runtime as a reference to a batch, so then I think of a
`batch-ref<α>` as storing the batch because it's storing its type
parameter (represented at runtime with a batch pointer) and storing
the index because all it's actually doing is wrapping the index.

Keeping the types in mind reminds you to transform vectors appropriately. So for example, if you have a vector of expressions, you can transform them into a single batch using a method:

```
exprs2batch : vector<α, Expr> → ∃ β. vector<α, β> × vector<β, Node>
```

Note that the `β` type parameter is existentially bound here; that is,
you do not a priori know how big the output batch will be, but you can
still safely access its elements because you have that rootvec to map
from `α` to `β`.

## Batches for multiple types

Batches make a lot of sense for expressions, but actually there are lots of different recursive data structures in Herbie: expressions, derivations, types, and so on. We probably want to use batches for all of them (though expressions are definitely most common and most performance-sensitive).

I find the types particularly illuminating in this case. The place to start is the concept of the μ type:

```
μ<F> := F<μ<F>>
```

This is not a helpful definition, so instead let's look at an example. An expression is either a literal or an operator applied to some arguments:

```
Expr := Literal | Apply(Op, List<Expr>)
```

The "recursive shape" of this type can be described by this
non-recursive, parameterized type:^{1}

```
ExprF<T> := Literal | Apply(Op, List<T>)
```

The μ type then goes from this "functor" to the recursive structure,
recovering `Expr = μ<ExprF>`. So how does this connect to batches?

Well, it answers, for one, what a "node" is, which I did by example
above. A node is one layer of the expression functor, with the
recursive positions replaced by indices (backreferences). In other
words, by `vector<α, Node>`, we mean:

```
vector<α, ExprF<α>>
```

The type `ExprF<α>` means "expression, but with all the backreferences
replaced by indices into `α`". So `exprs2batch` can be generalized into a
function `flatten` like this:

```
flatten : vector<α, μ<F>> → ∃ β. vector<α, β> × vector<β, F<β>>
```

This is the kind of thing people make fun of Haskell for: one-line
definitions that take three pages of explanation. Anyway, I write
`batch<β, F>` for `vector<β, F<β>>`.

In Racket, we don't have all these types but there's a more basic dynamic programming option. Instead of μ types and all that jazz, we have a "recursor":

```racket
(define (expr-recursor expr f)
  (match expr
    [(list op args ...) (list* op (map f args))]
    [other other]))
```

This is like a recursive walk over expressions, except the recursive
call is replaced with `f`. In type theory terms it is the `fmap` of the
functor.

Then a batch stores its recursor, and `deref` works like this:

```racket
(define (deref bref)
  (match-define (batch-ref b idx) bref)
  (match-define (batch nodes recursor) b)
  (recursor (vector-ref nodes idx)
            (lambda (i) (batch-ref b i))))
```

This is now generic over recursors, and other helper functions like
`batch-replace` and zombie culling can also be made generic in this way.

Moreover, if you want to write a function that transforms one recursive structure to another (for example, type checking an expression) you just need to take care to create input and output batches with different recursors. In a typed language you wouldn't need to store the recursor—that would be done with and enforced by types—but in Racket this is still pretty ergonomic.
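The recursor-carrying batch can be sketched in Python too (hypothetical names, not Herbie's API): the batch stores its `fmap`, and `deref` uses it without knowing the node shape.

```python
# Hypothetical sketch: a batch that carries its own recursor (the
# functor's fmap), so deref works for any node shape.

class Batch:
    def __init__(self, recursor):
        self.nodes, self.index, self.recursor = [], {}, recursor

    def add(self, node):
        """Deduplicating add; returns the node's index."""
        if node not in self.index:
            self.index[node] = len(self.nodes)
            self.nodes.append(node)
        return self.index[node]

def expr_recursor(node, f):
    """fmap for expression nodes: apply f to each backreference."""
    if isinstance(node, tuple):
        return (node[0],) + tuple(f(r) for r in node[1:])
    return node

def deref(batch, idx):
    """Pull out one level of a node, wrapping each backreference as a
    (batch, index) pair -- a stand-in for a batchref struct."""
    return batch.recursor(batch.nodes[idx], lambda i: (batch, i))
```

A batch over a different recursive structure (types, derivations) would be constructed the same way, just with a different recursor function.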

(Specifically, in Racket, helper functions like `batch-replace` that
output a batch have a variant, `batch-replace!`, that outputs into a
caller-supplied mutable batch. You can construct that mutable batch
with whatever recursor you want. In Racket, having mutable and
immutable versions of the same data structure and the same methods is
fairly common.)

It's also possible to have data structures that mix multiple types of recursive data structures. For example, if you have expressions and types, and the expressions contain type annotations, you really have something like this:

```
TypeF<T> := BaseT | FnT(T, T)
Type := μ<TypeF>
ExprF<Type><T> := Var(Int) | App(T, T) | Lam(Int, Type, T)
Expr := μ<ExprF<Type>>
```

Note that the type algebra is standard, but the expression algebra has
a parameter `T` for subexpressions and also a parameter `Type` for types.
(Which I've written curried.) If you want to store a bunch of
expressions (and their associated types) in batches, it'd look like
this:

```
exprs : batch<α, ExprF<β>> = vector<α, ExprF<β><α>>
types : batch<β, TypeF>    = vector<β, TypeF<β>>
```

In other words, the batch of expressions `exprs` stores expressions,
except with subexpressions replaced by backreferences and with types
replaced by references into the batch of types, which itself stores
backreferences to itself. While this is obviously kind-of confusing,
you can keep it somewhat straight with types or recursors.

## Conclusion

In Herbie we are making a big push to use batches everywhere we can. I think it will make Herbie faster and use less memory.

Of course, Herbie already has a lot of code (about 15k lines!) and the conversion will be slow, and right now we lack experience working with batches. The examples above show how to convert basic functions over expressions into ones that work on batches, and the abstractions like batch-refs and recursors show that it's possible to operate on batches at a reasonably high level.

As my colleagues and I get more experience with batches, we hope to discover even more such abstractions and ultimately make batches easy to work with.

## Footnotes:

^{1}

By the way, in an earlier draft of this blog post I made a total mess of the type theory but James set me straight. Thanks, James!