Batches for Recursive Data Structures
Most of Herbie's runtime is spent thinking about expressions: generating them, evaluating them, mutating them, analyzing them, filtering them, and so on. So naturally this all has to be as fast as possible.
Herbie is a primarily mutational synthesis engine—most expressions it thinks about are generated by changing, in small ways, another expression—so most of its expressions are very similar and share subexpressions. This makes duplication a big problem: Herbie wastes a lot of time analyzing identical subexpressions repeatedly. Caching things can help, but caches introduce their own problems: hashing large expressions is slow, hash tables are slow, garbage collection becomes an issue, and memory usage increases. We are therefore rewriting Herbie to focus on batches.
A batch is a linear encoding of a recursive data structure. This is a well-known performance trick (Adrian Sampson has a good walkthrough), but the details of actually working with this representation have been challenging, and in this blog post I want to walk through some of the things I've learned.
The basics of batches
A batch stores an expression by storing each node in a vector, with
back-references to other nodes. An expression like x * tan(x)
ends up
stored as a batch of three nodes:
%0 = x
%1 = tan(%0)
%2 = %0 * %1
Batches never contain duplicates (they are hashconsed), so for expressions with a lot of shared structure they can save a lot of memory.
In Herbie, which is written in Racket, the precise representation is a vector of nodes, where each node is a variable, a numeric constant, or a list where the first element is an operator name and the rest are indexes into the vector for the backreferences:
#(x (tan 0) (* 0 1))
To convert an expression into this form, you can recursively traverse the expression, adding every node and returning its index in the vector. As you do that, you can deduplicate by storing a hash table mapping nodes to indices. In Herbie, we have "mutable batches" which are a hash table plus a list, and "immutable batches" which store just the vector.
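To make that concrete, here's a minimal sketch of such a conversion; the name expr->batch, and the use of a hash table plus a reversed list as the mutable half, are illustrative rather than Herbie's actual code.

```racket
(define (expr->batch expr)
  (define nodes '())          ; nodes accumulated in reverse order
  (define index (make-hash))  ; node -> index, for deduplication
  (define size 0)
  (define (add! expr)
    (define node
      (match expr
        [(list op args ...) (list* op (map add! args))]
        [atom atom]))
    (cond
      [(hash-ref index node #f)]  ; already present: reuse its index
      [else
       (hash-set! index node size)
       (set! nodes (cons node nodes))
       (set! size (+ size 1))
       (hash-ref index node)]))
  (add! expr)
  (list->vector (reverse nodes)))
```

Running (expr->batch '(* x (tan x))) produces exactly the vector shown above, with the shared x node stored once.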
Computing over batches
Doing normal programming with batches has a common pattern. For example, suppose you want to convert from a batch to an expression. There's a bad way and a good way. The bad way is recursive:
(define (batch-get batch idx)
  (match (vector-ref batch idx)
    [(list op refs ...)
     (list* op (map (curry batch-get batch) refs))]
    [atom atom]))
The reason this is bad is that if the expression has a lot of shared subexpressions, you end up traversing those subexpressions a lot of times, and you end up allocating memory for each copy. In the worst case, this is exponential in both space and time. The good way is iterative:
(define (batch-get batch idx)
  (define batch-exprs (make-vector (vector-length batch)))
  (for ([node (in-vector batch)] [i (in-naturals)])
    (define expr
      (match node
        [(list op refs ...)
         (list* op (map (curry vector-ref batch-exprs) refs))]
        [atom atom]))
    (vector-set! batch-exprs i expr))
  (vector-ref batch-exprs idx))
It's way more code, but it uses a common pattern with batches: instead
of writing a recursive traversal, we allocate a temporary vector that
will store all the recursive outputs, and then fill it with a simple
for
loop over batch nodes. Note that as we fill the batch-exprs
vector, we also reference previous elements using the back-references
in the input batch.
Here's another example. Suppose we want to compute the depth of an expression, but stored as a batch instead of as a normal expression tree. You can actually compute depth bottom-up (the depth of an expression is one plus the depth of the deepest subexpression) or top-down (the depth of the root is 0, the depth of a subexpression is one more than the depth of its parent). When you recurse, there's a subtle difference, because one of these options is tail-recursive and one isn't. But we can do either one with batches.
Here's bottom-up as a recursive function:
(define (depth expr)
  (match expr
    [(list op args ...)
     (define child-depths (map depth args))
     (+ 1 (apply max child-depths))]
    [other 0]))
Here's the same thing on a batch:
(define (batch-depth batch)
  (define depths (make-vector (vector-length batch)))
  (for ([node (in-vector batch)] [idx (in-naturals)])
    (vector-set! depths idx
                 (match node
                   [(list op args ...)
                    (define child-depths (map (curry vector-ref depths) args))
                    (+ 1 (apply max child-depths))]
                   [other 0])))
  (vector-ref depths (- (vector-length depths) 1)))
It's the same body, basically, except we allocate a vector of outputs and instead of making recursive calls we look in that output vector.
Top-down depth is a little trickier even recursively: we need to
update a mutable reference (a box
in Racket) every time we reach a
leaf node:
(define (depth node output [acc 0])
  (match node
    [(list op args ...)
     (for ([arg (in-list args)])
       (depth arg output (+ acc 1)))]
    [other (set-box! output (max (unbox output) acc))]))
The caller has to allocate the output cell and read the answer from
it. The default acc
value of 0 gives the depth of the root of the
tree. Note that I'm assuming operators have at least one argument;
this is an example, not necessarily live code. Here's the same thing
but with a batch:
(define (batch-depth batch output)
  (define accs (make-vector (vector-length batch) 0))
  (define last-idx (- (vector-length batch) 1))
  (for ([idx (in-range last-idx -1 -1)])
    (define acc (vector-ref accs idx))
    (match (vector-ref batch idx)
      [(list op args ...)
       (for ([arg (in-list args)])
         (vector-set! accs arg (max (vector-ref accs arg) (+ acc 1))))]
      [other (set-box! output (max (unbox output) acc))])))
This also looks like the recursive version, but this time the
temporary vector tracks inputs to the recursive call instead of
outputs from it. Also, there's one tricky difference: the recursive
version can go into the same subtree twice with different acc
arguments. The batch version can't, so the batch version computes the
max
of all recursive calls to the same node. So for top-down
traversals, a little bit of extra logic is needed to convert from
expressions to batches. Otherwise, though, there's again a lot of
shared code.
More generally, any bottom-up tree traversal becomes a loop through the nodes in order with a side table of recursive outputs, while any top-down tree traversal becomes a loop through the nodes in reverse order with a side table of recursive inputs. A recursive algorithm on expressions can be broken down into a sequence of bottom-up and top-down traversals in order to do it on batches.
Here's one realistic example: A batch can contain "zombie" nodes not referenced by the root node. (This can happen due to transformations, for example.) You can remove them through two traversals. First a top-down traversal marks all nodes reachable from the root. Then a bottom-up traversal copies the reachable nodes to a new vector. (The top-down depth code above, by the way, isn't correct in the presence of zombie nodes.)
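As a sketch, here's what those two passes might look like on the bare vector representation from earlier; cull is an illustrative name, and this is a simplification rather than Herbie's actual code. It returns the new batch together with the root's new index.

```racket
(define (cull batch root)
  ;; Top-down pass: mark every node reachable from the root. Since
  ;; back-references always point to earlier nodes, a reverse loop works.
  (define live? (make-vector (vector-length batch) #f))
  (vector-set! live? root #t)
  (for ([idx (in-range (- (vector-length batch) 1) -1 -1)])
    (when (vector-ref live? idx)
      (match (vector-ref batch idx)
        [(list op refs ...)
         (for ([ref (in-list refs)]) (vector-set! live? ref #t))]
        [_ (void)])))
  ;; Bottom-up pass: copy live nodes into a new vector, reindexing
  ;; back-references as we go.
  (define reindex (make-vector (vector-length batch) #f))
  (define out '())
  (define size 0)
  (for ([idx (in-naturals)] [node (in-vector batch)])
    (when (vector-ref live? idx)
      (define node*
        (match node
          [(list op refs ...)
           (list* op (for/list ([ref (in-list refs)]) (vector-ref reindex ref)))]
          [atom atom]))
      (vector-set! reindex idx size)
      (set! out (cons node* out))
      (set! size (+ size 1))))
  (values (list->vector (reverse out)) (vector-ref reindex root)))
```

For example, culling #(x y (tan 0) (* 0 2)) with root 3 drops the zombie y node and returns #(x (tan 0) (* 0 1)) with root 2.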
This pattern of turning recursion into iteration means you might do more calls than you need to, but you'll never do duplicate calls. This is typically a win. It might also lead to less allocation, though I think in a language like Racket that impact is hard to predict. In something like C or Rust, where the nodes can be stored directly in the vector, the allocation impact will be real and large, and you'll also get potential benefits from things like cache locality and prefetching.
Performing rewrites
Recursive functions that return recursive data structures are a little trickier.
For example, suppose you want to write a function that finds and
replaces all (+ a (+ b c))
with (+ (+ a b) c)
. Since the input is
expressions and the output is expressions, this is a function from
batch to batch. Since it is bottom-up, it's implemented with a for
loop forward over the input batch.
Note that the cardinality of the input and output batches might
differ. That's true even for this specific associativity rewrite,
which replaces two nodes with two other nodes: the rewrite can cause
there to be more or less sharing. In (* (+ 1 2) (+ 1 (+ 2 3))), the
original expression's three addition nodes are all different, but
after rewriting, the subexpression (+ 1 2) appears twice.
So to implement this rewrite, you need to have three vectors: the input batch; the output batch; and a reindexing vector that maps the index of a node in the input batch to the index of the rewritten node in the output batch.
It looks something like this, where b1
is the input (immutable) batch,
b2
is the output (mutable) batch, and reindex
is the reindexing
vector. The batch-add!
method adds a node to a mutable batch and
returns its index; it'll deduplicate the node when adding so the index
might not be the end of the batch.
(for ([node (in-vector b1)] [idx (in-naturals)])
  (match node
    [(list '+ a rhs)
     (match (vector-ref b1 rhs)
       [(list '+ b c)
        (define idx-a (vector-ref reindex a))
        (define idx-b (vector-ref reindex b))
        (define idx-c (vector-ref reindex c))
        (define idx-a+b (batch-add! b2 (list '+ idx-a idx-b)))
        (define idx-lhs+c (batch-add! b2 (list '+ idx-a+b idx-c)))
        (vector-set! reindex idx idx-lhs+c)])]))
Note that in the pattern matching, you end up needing to index back
into the input batch to figure out what the backreferences mean. Then,
you need to translate the remaining backreferences using the
reindexing vector before adding the rewritten nodes into the output
batch. Also I'm not handling the else
case, which ends up
consuming many more lines of code.
The upshot is that this kind of rewrite is possible, but also kind of a pain. And dangerous! If you forget to reindex one of the backreferences, you end up with garbage, and that's a very hard bug to debug.
Batchrefs
One solution to the danger here is types, or, since this is Racket, dynamic type checks.
The key abstraction is the "batchref", which is an index into a batch,
but wrapped in a way that keeps track of the batch. You can deref
a
batchref to index into the batch and pull out "one level" of the node:
(struct batch-ref (batch idx))

(define (deref bref)
  (define node (vector-ref (batch-ref-batch bref) (batch-ref-idx bref)))
  (match node
    [(list op args ...)
     (list* op (map (curry batch-ref (batch-ref-batch bref)) args))]
    [other other]))
Note that after we pull out the node, we replace its back-references
with batch-ref
objects, to continue to enforce safety.
Then the pattern matching can be folded away like this:
(match (deref (batch-ref b1 idx))
  [(list '+ a (app deref (list '+ b c)))
   ...]
  [(list op args ...)
   (batch-add! b2 (list* op (map (compose (curry vector-ref reindex) batch-ref-idx) args)))]
  [other (batch-add! b2 other)])
Here I've used deref
together with a neat Racket feature to make the
pattern matching nicer. First, I construct a batch-ref
and deref
it
before entering the match
; this replaces all my back-references with
batch-ref
objects.
Then, in the pattern match, I match (app deref ...)
, a Racket app
pattern. This matches a batch-ref
, calls deref
on it, and then matches
on the output; this lets me turn the nested match
calls into a flat
call, which also means I only need to write the else
case once. app
patterns are a nice feature of Racket, kind of analogous to view
patterns in Haskell or extractors in Scala.
In the main pattern matching case, a
, b
, and c
are now batchrefs. The
next idea is just to construct the replacement expression out of these
batchrefs:
(match (deref (batch-ref b1 idx))
  [(list '+ a (app deref (list '+ b c)))
   (define new-expr (list '+ (list '+ a b) c))
   ...]
  ...)
The new-expr
is an expression tree where the leaves are references
into the input batch. We can then recursively add every node in that
expression, translating any batchrefs from old to new batch using the
reindexing table:
(define (batch-add-transformed! b2 reindex expr)
  (match expr
    [(batch-ref b1 idx) (vector-ref reindex idx)]
    [(list op args ...)
     (batch-add! b2 (list* op (map (curry batch-add-transformed! b2 reindex) args)))]
    [atom (batch-add! b2 atom)]))
Then the actual transformation looks like this:
(match (deref (batch-ref b1 idx))
  [(list '+ a (app deref (list '+ b c)))
   (define new-idx
     (batch-add-transformed! b2 reindex (list '+ (list '+ a b) c)))
   (vector-set! reindex idx new-idx)]
  ...)
You can wrap this whole pattern into a helper function, so that the final invocation looks like this:
(batch-replace b1
               (lambda (node)
                 (match node
                   [`(+ ,a ,(app deref `(+ ,b ,c))) `(+ (+ ,a ,b) ,c)]
                   [other other])))
The batch-replace helper handles all of the allocation and
reindexing: the node passed into the callback uses batch-ref for all
the back-references (hence the deref), and the helper takes care of
constructing the new batch and updating all the indices. In Racket
you can even write some custom syntax around this; we're still
playing with what exactly that syntax should be and haven't yet
settled on anything.
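Putting the pieces together, here's a self-contained sketch of how batch-replace might be implemented. The mutable-batch representation (mbatch) is a simplified stand-in for Herbie's, and deref and batch-add-transformed! repeat the definitions above; none of this is Herbie's actual code.

```racket
(struct batch-ref (batch idx))

;; A simplified mutable batch: hash-cons table plus nodes in reverse order.
(struct mbatch (index nodes size) #:mutable)
(define (make-mbatch) (mbatch (make-hash) '() 0))

(define (mbatch-add! mb node)
  (cond
    [(hash-ref (mbatch-index mb) node #f)]  ; duplicate: reuse its index
    [else
     (define idx (mbatch-size mb))
     (hash-set! (mbatch-index mb) node idx)
     (set-mbatch-nodes! mb (cons node (mbatch-nodes mb)))
     (set-mbatch-size! mb (+ idx 1))
     idx]))

(define (mbatch->vector mb)
  (list->vector (reverse (mbatch-nodes mb))))

;; Pull out one level of a node, wrapping back-references as batch-refs.
(define (deref bref)
  (define b (batch-ref-batch bref))
  (match (vector-ref b (batch-ref-idx bref))
    [(list op refs ...) (list* op (map (curry batch-ref b) refs))]
    [atom atom]))

;; Add an expression whose leaves are batch-refs into the output batch,
;; translating old indices through the reindexing vector.
(define (batch-add-transformed! b2 reindex expr)
  (match expr
    [(batch-ref _ old-idx) (vector-ref reindex old-idx)]
    [(list op args ...)
     (mbatch-add! b2 (list* op (map (curry batch-add-transformed! b2 reindex) args)))]
    [atom (mbatch-add! b2 atom)]))

;; Apply a node-to-expression callback to every node of the input batch.
(define (batch-replace b1 f)
  (define b2 (make-mbatch))
  (define reindex (make-vector (vector-length b1)))
  (for ([idx (in-range (vector-length b1))])
    (define new-expr (f (deref (batch-ref b1 idx))))
    (vector-set! reindex idx (batch-add-transformed! b2 reindex new-expr)))
  (mbatch->vector b2))
```

With these definitions, the associativity rewrite from above returns a new batch in one call. Note that the output batch can contain a zombie node when the rewritten addition was the only use of one of its operands.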
Batches of multiple expressions
A batch can also store multiple expressions. In Herbie this is extremely important because all of our expressions are quite similar. In one key part of Herbie we already do this, and the batch is 25 times smaller than the expressions themselves.
A batch storing multiple expressions looks the same as a normal batch.
But we now also want to keep track of where each initial expression is
in the batch; we call that second vector the "rootvec". To go back
from the batch to the expressions, you use the iterative algorithm
above, and at the end read one expression for each entry of the
rootvec
. It's pretty fast.
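As a sketch, reading all the expressions back out might look like this; batch->exprs is an illustrative name, with the rootvec passed as a vector of indices.

```racket
(define (batch->exprs batch roots)
  ;; One forward pass builds the expression for every node, then we
  ;; read off one expression per root.
  (define exprs (make-vector (vector-length batch)))
  (for ([node (in-vector batch)] [idx (in-naturals)])
    (vector-set! exprs idx
                 (match node
                   [(list op refs ...)
                    (list* op (map (curry vector-ref exprs) refs))]
                   [atom atom])))
  (for/list ([root (in-vector roots)])
    (vector-ref exprs root)))
```

For example, on the batch #(x (tan 0) (* 0 1)) with rootvec #(1 2), this returns the two expressions (tan x) and (* x (tan x)).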
As soon as you have multiple expressions, it becomes even easier to have nodes not reachable from any of the "root nodes". We call these "zombie" nodes, and you have to "cull" them every now and then. If you think about it, the batch is kind of like an arena allocator for expressions, and zombie culling is just manual garbage collection. The same top-down-then-bottom-up pattern described earlier works. It's basically a mark-and-copy collector.
There's a bit of a trade-off here: if you have too many zombie nodes, you waste time analyzing or rewriting them, but if you cull them too often, that also takes a lot of time. This is the usual garbage collection frequency tradeoff. So sometimes you make judgment calls about whether or not to delete those zombie nodes.
For example, Herbie's series expander rewrites pow(x, 1/2)
into
sqrt(x)
, because the series expander thinks about real semantics and
it has custom sqrt
code that's faster than its generic pow
code. The
rewrite itself is done just like above, using batch-replace
. But it
can leave a zombie 1/2
node in the batch, if the pow
was the only
expression that used 1/2
. We could cull the zombies afterwards—but
series-expanding a constant is very cheap, so it's better to just
leave the zombie node around.
A type perspective
One big problem working with batches is that you end up with lots of vectors. A batch of multiple expressions is already two vectors (the batch and the rootvec). A rewrite involves a reindexing vector and two different batches. Then you map over the batches to produce new vectors. Pretty quickly you have a dozen vectors, and it's hard to remember which vector is parallel to which other vector, and which indices need to be transformed how. It's a mess.
I find it helpful to think about this from a type perspective. Imagine
a vector not as a type vector<T>
but as a type vector<α, T>
where α
represents the "index space". You can think of it as the vector
length, but I prefer thinking of it as a phantom type, totally
meaningless on its own. In practice, we represent the index space as a
reference to a batch. Each batch is its own index space, and then if
you compute some value for each node in a batch, that has the same
index space as the batch.
If you have two vectors with the same length, like in the depth code above, then both vectors have the same index space:
batch  : vector<α, Node>
depths : vector<α, Nat>
When two vectors have the same index space you can traverse them in parallel.
Alternatively, if you have two vectors of different lengths, you think of them as having different index spaces:
input-batch  : vector<α, Node>
output-batch : vector<β, Node>
When you have a reindexing vector, it maps from one indexing space to another:
reindex : vector<α, β>
A batch-ref
is a batch plus an index, which we think of as a member of
the index space. So:
batch-ref<α> := vector<α, Node> × α
deref : vector<α, Node> × α → Node
That said, it's sometimes better to identify batch-ref<α>
with α
itself, that is, to think about a batch-ref
as a kind of safe index
type. I already mentioned that sometimes you want to represent the
type α
at runtime as a reference to a batch, so then I think of a
batch-ref<α>
as storing the batch because it's storing its type
parameter (represented at runtime with a batch pointer) and storing
the index because all it's actually doing is wrapping the index.
Keeping the types in mind reminds you to transform vectors appropriately. So for example, if you have a vector of expressions, you can transform them into a single batch using a method:
exprs2batch : vector<α, Expr> → ∃ β. vector<α, β> × vector<β, Node>
Note that the β
type parameter is existentially bound here, that is,
you do not a priori know how big the output batch will be, but you can
still safely access its elements because you have that rootvec to map
from α
to β
.
Batches for multiple types
Batches make a lot of sense for expressions, but actually there are lots of different recursive data structures in Herbie: expressions, derivations, types, and so on. We probably want to use batches for all of them (though expressions are definitely most common and most performance-sensitive).
I find the types particularly illuminating in this case. The place to start is the concept of the μ type:
μ<F> := F<μ<F>>
This is not a helpful definition, so instead let's look at an example. An expression is either literal or an operator applied to some arguments:
Expr := Literal | Apply(Op, List<Expr>)
The "recursive shape" of this type can be described by this non-recursive, parameterized type:1 [1 By the way, in an earlier draft of this blog post I made a total mess of the type theory but James set me straight. Thanks, James!]
ExprF<T> := Literal | Apply(Op, List<T>)
The μ type then goes from this "functor" to the recursive structure,
recovering Expr = μ<ExprF>
. So how does this connect to batches?
Well, it answers, for one, what a "node" is, which I did by example
above. A node is one layer of the expression functor, parameterized
by indices (backreferences). In other words, by vector<α, Node>, we mean:
vector<α, ExprF<α>>
The type ExprF<α>
means "expression, but with all the backreferences
replaced by indices into α
". So expr2batch
can be generalized into a
function flatten
like this:
flatten : vector<α, μ<F>> → ∃ β. vector<α, β> × vector<β, F<β>>
This is the kind of thing people make fun of Haskell for: one-line
definitions that take three pages of explanation. Anyway, I write
batch<β, F>
for vector<β, F<β>>
.
In Racket, we don't have all these types, but there's a more basic, dynamically typed option. Instead of μ types and all that jazz, we have a "recursor":
(define (expr-recursor expr f)
  (match expr
    [(list op args ...) (list* op (map f args))]
    [other other]))
This is like a recursive walk over expressions, except the recursive
call is replaced with f
. In type theory terms it is the fmap
of the
functor.
Then a batch stores its recursor, and deref
works like this:
(define (deref bref)
  (match-define (batch-ref b idx) bref)
  (match-define (batch nodes recursor) b)
  (recursor (vector-ref nodes idx)
            (lambda (idx) (batch-ref b idx))))
This is now generic over recursors, and other helper functions like
batch-replace
and zombie culling can also be made generic in this way.
Moreover, if you want to write a function that transforms one recursive structure to another (for example, type checking an expression) you just need to take care to create input and output batches with different recursors. In a typed language you wouldn't need to store the recursor—that would be done with and enforced by types—but in Racket this is still pretty ergonomic.
(Specifically, in Racket, helper functions like batch-replace that
output a batch have a variant, batch-replace!, that outputs into a
caller-supplied mutable batch. You can construct that mutable batch
with whatever recursor you want. In Racket, having mutable and
immutable versions of the same data structure and the same methods is
fairly common.)
It's also possible to have data structures that mix multiple types of recursive data structures. For example, if you have expressions and types, and the expressions contain type annotations, you really have something like this:
TypeF<T> := BaseT | FnT(T, T)
Type := μ<TypeF>
ExprF<Type><T> := Var(Int) | App(T, T) | Lam(Int, Type, T)
Expr := μ<ExprF<Type>>
Note that the type algebra is standard, but the expression algebra
has a parameter T for subexpressions and also a parameter Type for
types. (Which I've written curried.) If you want to store a bunch of
expressions (and their associated types) in batches, it'd look like
this:
exprs : batch<α, ExprF<β>> = vector<α, ExprF<β><α>>
types : batch<β, TypeF>    = vector<β, TypeF<β>>
In other words, the batch of expressions exprs stores expressions,
except with subexpressions replaced by backreferences and with types
replaced by references into the batch of types, which in turn stores
its own backreferences. While this is obviously kind-of confusing,
you can keep it somewhat straight with types or recursors.
Conclusion
In Herbie we are making a big push to use batches everywhere we can. I think it will make Herbie faster and use less memory.
Of course, Herbie already has a lot of code (about 15k lines!) and the conversion will be slow, and right now we lack experience working with batches. The examples above show how to convert basic functions over expressions into ones that work on batches, and abstractions like batch-refs and recursors show that it's possible to operate on batches at a reasonably high level.
As my colleagues and I get more experience with batches, we hope to discover even more such abstractions and ultimately make batches easy to work with.