By Pavel Panchekha

Share under CC-BY-SA.

Batches for Recursive Data Structures

Most of Herbie's time is spent thinking about expressions: generating them, evaluating them, mutating them, analyzing them, filtering them, and so on. So naturally this has to be as fast as possible.

Herbie is a primarily mutational synthesis engine—most expressions it thinks about are generated by changing, in small ways, another expression—so most of its expressions are very similar and share subexpressions. This makes duplication a big problem: Herbie wastes a lot of time analyzing identical subexpressions repeatedly. Caching things can help, but caches introduce their own problems: hashing large expressions is slow, hash tables are slow, garbage collection becomes an issue, and memory usage increases. We are therefore rewriting Herbie to focus on batches.

A batch is a linear encoding of a recursive data structure. This is a well-known performance trick (Adrian Sampson has a good walkthrough), but the details of actually working with this representation have been challenging, and in this blog post I want to walk through some of the things I've learned.

The basics of batches

A batch stores an expression by storing each node in a vector, with back-references to other nodes. An expression like x * tan(x) ends up stored by a batch of three elements:

%0 = x
%1 = tan(%0)
%2 = %0 * %1

Batches never contain duplicates (they are hashconsed), so for expressions with a lot of shared structure they can save a lot of memory.

In Herbie, which is written in Racket, the precise representation is a vector of nodes, where each node is a variable, a numeric constant, or a list where the first element is an operator name and the rest are indexes into the vector for the backreferences:

#(x
   (tan 0)
   (* 0 1))

To convert an expression into this form, you can recursively traverse the expression, adding every node and returning its index in the vector. As you do that, you can deduplicate by storing a hash table mapping nodes to indices. In Herbie, we have "mutable batches" which are a hash table plus a list, and "immutable batches" which store just the vector.
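As a minimal sketch of that conversion (not Herbie's actual code), the following builds a batch from an expression. The names mbatch, make-mbatch, mbatch-add!, mbatch-add-expr!, and mbatch->vector are all made up for illustration; the mutable batch is assumed to be a hash table plus a reversed list of nodes, as described above.

```racket
#lang racket

;; Hypothetical mutable batch: a hash table for deduplication plus a
;; reversed list of nodes (newest first). Illustrative names only.
(struct mbatch (index [nodes #:mutable] [size #:mutable]))

(define (make-mbatch)
  (mbatch (make-hash) '() 0))

;; Add a node (whose arguments are already indices) and return its index.
;; If an equal node is already present, return the existing index instead.
(define (mbatch-add! b node)
  (hash-ref! (mbatch-index b)
             node
             (lambda ()
               (define idx (mbatch-size b))
               (set-mbatch-nodes! b (cons node (mbatch-nodes b)))
               (set-mbatch-size! b (+ idx 1))
               idx)))

;; Recursively add every node of expr, children first.
(define (mbatch-add-expr! b expr)
  (match expr
    [(list op args ...)
     (mbatch-add! b (cons op (map (curry mbatch-add-expr! b) args)))]
    [atom (mbatch-add! b atom)]))

;; Freeze a mutable batch into a plain vector of nodes.
(define (mbatch->vector b)
  (list->vector (reverse (mbatch-nodes b))))

;; x * tan(x): the shared x is stored once.
(define b (make-mbatch))
(define root (mbatch-add-expr! b '(* x (tan x))))
(mbatch->vector b)
```

The hash table is only needed while nodes are being added, which is why an immutable batch can drop it and keep just the vector.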

Computing over batches

Doing normal programming with batches has a common pattern. For example, suppose you want to convert from a batch to an expression. There's a bad way and a good way. The bad way is recursive:

(define (batch-get batch idx)
  (match (vector-ref batch idx)
    [(list op refs ...)
     (list* op (map (curry batch-get batch) refs))]
    [atom atom]))

The reason this is bad is that if the expression has a lot of shared subexpressions, you end up traversing those subexpressions a lot of times, and you end up allocating memory for each copy. In the worst case, this is exponential in both space and time. The good way is iterative:

(define (batch-get batch idx)
  ;; batch-exprs[i] will hold the expression rooted at node i
  (define batch-exprs (make-vector (vector-length batch)))
  (for ([node (in-vector batch)] [i (in-naturals)])
    (define expr
      (match node
        [(list op refs ...)
         (list* op (map (curry vector-ref batch-exprs) refs))]
        [atom atom]))
    (vector-set! batch-exprs i expr))
  (vector-ref batch-exprs idx))

It's way more code, but it uses a common pattern with batches: instead of writing a recursive traversal, we allocate a temporary vector that will store all the recursive outputs, and then fill it with a simple for loop over batch nodes. Note that as we fill the batch-exprs vector, we also reference previous elements using the back-references in the input batch.
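To make the cost of sharing concrete, here is a small sketch using the same side-table pattern. It assumes a made-up worst-case batch (doubling-batch, a hypothetical helper) in which node i is the sum of node i-1 with itself, and it computes the size of the fully expanded tree under each node without ever building that tree. The tree under node n has 2^(n+1) - 1 nodes, so the recursive expansion would be hopeless for n = 40, while the loop below does 41 steps.

```racket
#lang racket

;; Hypothetical helper: node 0 is x, and node i is (+ %i-1 %i-1),
;; so every node is used twice by the next one.
(define (doubling-batch n)
  (for/vector #:length (+ n 1) ([i (in-range (+ n 1))])
    (if (= i 0) 'x (list '+ (- i 1) (- i 1)))))

;; Size of the expanded tree under each node, computed with a side table:
;; one forward pass, no recursion, no expansion.
(define (batch-tree-sizes batch)
  (define sizes (make-vector (vector-length batch) 0))
  (for ([node (in-vector batch)] [idx (in-naturals)])
    (vector-set! sizes idx
                 (match node
                   [(list op args ...)
                    (+ 1 (apply + (map (curry vector-ref sizes) args)))]
                   [atom 1])))
  sizes)

(define sizes (batch-tree-sizes (doubling-batch 40)))
(printf "batch nodes: ~a, expanded tree nodes: ~a\n"
        (vector-length sizes)
        (vector-ref sizes 40))
```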

Here's another example. Suppose we want to compute the depth of an expression, but stored as a batch instead of as a normal expression tree. You can actually compute depth bottom-up (the depth of an expression is one plus the depth of the deepest subexpression) or top-down (the depth of the root is 0, the depth of a subexpression is one more than the depth of its parent). When you recurse, there's a subtle difference, because one of these options is tail-recursive and one isn't. But we can do either one with batches.

Here's bottom-up as a recursive function:

(define (depth expr)
  (match expr
    [(list op args ...)
     (define child-depths (map depth args))
     (+ 1 (apply max child-depths))]
    [other
     0]))

Here's the same thing on a batch:

(define (batch-depth batch)
  (define depths (make-vector (vector-length batch)))
  (for ([node (in-vector batch)] [idx (in-naturals)])
    (vector-set! depths idx
                 (match node
                   [(list op args ...)
                    (define child-depths (map (curry vector-ref depths) args))
                    (+ 1 (apply max child-depths))]
                   [other
                    0])))
  ;; The root is the last node in the batch
  (vector-ref depths (- (vector-length depths) 1)))

It's the same body, basically, except we allocate a vector of outputs and instead of making recursive calls we look in that output vector.

Top-down depth is a little trickier even recursively: we need to update a mutable reference (a box in Racket) every time we reach a leaf node:

(define (depth node output [acc 0])
  (match node
    [(list op args ...)
     (for ([arg (in-list args)])
       (depth arg output (+ acc 1)))]
    [other
     (set-box! output (max (unbox output) acc))]))

The caller has to allocate the output cell and read the answer from it. The default acc value of 0 gives the depth of the root of the tree. Note that I'm assuming operators have at least one argument; this is an example, not necessarily live code. Here's the same thing but with a batch:

(define (batch-depth batch output)
  (define accs (make-vector (vector-length batch) 0))
  (define last-idx (- (vector-length batch) 1))
  ;; Walk from the root (the last node) down to index 0, so that every
  ;; parent is processed before its children
  (for ([idx (in-range last-idx -1 -1)])
    (define node (vector-ref batch idx))
    (define acc (vector-ref accs idx))
    (match node
      [(list op args ...)
       (for ([arg (in-list args)])
         (vector-set! accs arg (max (vector-ref accs arg) (+ acc 1))))]
      [other
       (set-box! output (max (unbox output) acc))])))

This also looks like the recursive version, but this time the temporary vector tracks inputs to the recursive call instead of outputs from it. Also, there's one tricky difference: the recursive version can go into the same subtree twice with different acc arguments. The batch version can't, so the batch version computes the max of all recursive calls to the same node. So for top-down traversals, a little bit of extra logic is needed to convert from expressions to batches. Otherwise, though, there's again a lot of shared code.

More generally, any bottom-up tree traversal becomes a loop through the nodes in order with a side table of recursive outputs, while any top-down tree traversal becomes a loop through the nodes in reverse order with a side table of recursive inputs. A recursive algorithm on expressions can be broken down into a sequence of bottom-up and top-down traversals in order to do it on batches.

Here's one realistic example: A batch can contain "zombie" nodes not referenced by the root node. (This can happen due to transformations, for example.) You can remove them through two traversals. First a top-down traversal marks all nodes reachable from the root. Then a bottom-up traversal copies the reachable nodes to a new vector. (The top-down depth code above, by the way, isn't correct in the presence of zombie nodes.)
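A minimal sketch of those two traversals, assuming the root is the last node of a non-empty batch; batch-cull is a hypothetical name, not necessarily what Herbie calls it. No hash table is needed here, since removing nodes cannot introduce new duplicates.

```racket
#lang racket

;; Remove nodes unreachable from the root (assumed to be the last node).
;; Returns a new batch vector.
(define (batch-cull batch)
  (define n (vector-length batch))
  (define live? (make-vector n #f))
  (vector-set! live? (- n 1) #t)
  ;; Top-down: walk from the root toward index 0, marking the arguments
  ;; of every live node as live.
  (for ([idx (in-range (- n 1) -1 -1)])
    (when (vector-ref live? idx)
      (match (vector-ref batch idx)
        [(list op args ...)
         (for ([arg (in-list args)])
           (vector-set! live? arg #t))]
        [atom (void)])))
  ;; Bottom-up: copy live nodes in order, translating back-references
  ;; through a reindexing vector (old index -> new index).
  (define reindex (make-vector n #f))
  (for/fold ([out '()] [next 0] #:result (list->vector (reverse out)))
            ([node (in-vector batch)] [idx (in-naturals)]
             #:when (vector-ref live? idx))
    (define new-node
      (match node
        [(list op args ...)
         (cons op (map (curry vector-ref reindex) args))]
        [atom atom]))
    (vector-set! reindex idx next)
    (values (cons new-node out) (+ next 1))))

;; Nodes 0 (y) and 2 (y + y) are zombies: the root tan(x) never uses them.
(define culled (batch-cull (vector 'y 'x '(+ 0 0) '(tan 1))))
culled
```

Note that the surviving back-reference changes from 1 to 0, which is exactly the job of the reindexing vector.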

This pattern of turning recursion into iteration means you might do more calls than you need to, but you'll never do duplicate calls. This is typically a win. It might also lead to less allocation, though I think in a language like Racket that impact is hard to predict. In something like C or Rust, where the nodes can be stored directly in the vector, the allocation impact will be real and large, and you'll also get potential benefits from things like cache locality and prefetching.

Performing rewrites

Recursive functions that return recursive data structures are a little trickier.

For example, suppose you want to write a function that finds and replaces all (+ a (+ b c)) with (+ (+ a b) c). Since the input is expressions and the output is expressions, this is a function from batch to batch. Since it is bottom-up, it's implemented with a for loop forward over the input batch.

Note that the cardinality of the input and output batches might differ. That's true even for this specific associativity rewrite, which replaces two nodes with two other nodes: the rewrite can cause there to be more or less sharing. In (* (+ 1 2) (+ 1 (+ 2 3))), the original expression's 3 addition nodes are all different but after rewriting the subexpression (+ 1 2) appears twice.

So to implement this rewrite, you need to have three vectors: the input batch; the output batch; and a reindexing vector that goes from the index of the input node in the input batch to index of the rewritten node in the output batch.

It looks something like this, where b1 is the input (immutable) batch, b2 is the output (mutable) batch, and reindex is the reindexing vector. The batch-add! method adds a node to a mutable batch and returns its index; it'll deduplicate the node when adding so the index might not be the end of the batch.

(for ([node (in-vector b1)] [idx (in-naturals)])
  (match node
    [(list '+ a rhs)
     (match (vector-ref b1 rhs)
       [(list '+ b c)
        ;; Translate back-references from b1's indices to b2's
        (define idx-a (vector-ref reindex a))
        (define idx-b (vector-ref reindex b))
        (define idx-c (vector-ref reindex c))
        (define idx-a+b (batch-add! b2 (list '+ idx-a idx-b)))
        (define idx-lhs+c (batch-add! b2 (list '+ idx-a+b idx-c)))
        (vector-set! reindex idx idx-lhs+c)])]))

Note that in the pattern matching, you end up needing to index back into the input batch to figure out what the backreferences mean. Then, you need to translate the remaining backreferences using the reindexing vector before adding the rewritten nodes into the output batch. Also I'm not handling the else case, which ends up consuming many more lines of code.
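For completeness, here is one way the loop might look with the else cases filled in. This is a sketch under assumptions: the mutable batch (mbatch, make-mbatch, mbatch->vector) is a stand-in built from a hash table and a list, batch-add! is my guess at the behavior described above, and reassociate and copy-node! are hypothetical names.

```racket
#lang racket

;; Hypothetical mutable batch (hash table plus reversed node list);
;; batch-add! deduplicates and returns the node's index.
(struct mbatch (index [nodes #:mutable] [size #:mutable]))
(define (make-mbatch) (mbatch (make-hash) '() 0))
(define (batch-add! b node)
  (hash-ref! (mbatch-index b) node
             (lambda ()
               (define idx (mbatch-size b))
               (set-mbatch-nodes! b (cons node (mbatch-nodes b)))
               (set-mbatch-size! b (+ idx 1))
               idx)))
(define (mbatch->vector b) (list->vector (reverse (mbatch-nodes b))))

;; Rewrite (+ a (+ b c)) into (+ (+ a b) c) everywhere in b1.
;; Returns the output batch and the new index of b1's root (last node).
(define (reassociate b1)
  (define b2 (make-mbatch))
  (define reindex (make-vector (vector-length b1) #f))
  (define (translate i) (vector-ref reindex i))
  ;; Else case: copy a node unchanged, translating its back-references.
  (define (copy-node! node)
    (match node
      [(list op args ...) (batch-add! b2 (cons op (map translate args)))]
      [atom (batch-add! b2 atom)]))
  (for ([node (in-vector b1)] [idx (in-naturals)])
    (define new-idx
      (match node
        [(list '+ a rhs)
         (match (vector-ref b1 rhs)
           [(list '+ b c)
            (define idx-a+b
              (batch-add! b2 (list '+ (translate a) (translate b))))
            (batch-add! b2 (list '+ idx-a+b (translate c)))]
           [_ (copy-node! node)])]
        [_ (copy-node! node)]))
    (vector-set! reindex idx new-idx))
  (values (mbatch->vector b2)
          (vector-ref reindex (- (vector-length b1) 1))))

;; a + (b + c), stored as a batch
(define-values (out root)
  (reassociate (vector 'a 'b 'c '(+ 1 2) '(+ 0 3))))
out
```

In this example the copied b + c node stays in the output even though the rewritten root no longer refers to it, so the output has one more node than the input.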

The upshot is that doing this kind of rewrite is possible, but also kind of a pain. And dangerous! If you forget to reindex one of the backreferences you end up with garbage, and this is a very hard bug to debug.

Batchrefs

One solution to the danger here is types, or, since this is Racket, dynamic type checks.

The key abstraction is the "batchref", which is an index into a batch, but wrapped in a way that keeps track of the batch. You can deref a batchref to index into the batch and pull out "one level" of the node:

(struct batch-ref (batch idx))

(define (deref bref)
  (define node (vector-ref (batch-ref-batch bref) (batch-ref-idx bref)))
  (match node
    [(list op args ...)
     (list* op (map (curry batch-ref (batch-ref-batch bref)) args))]
    [other
     other]))

Note that after we pull out the node, we replace its back-references with batch-ref objects, to continue to enforce safety.

Then the pattern matching can be folded away like this:

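(match (deref (batch-ref b1 idx))
  [(list '+ a (app deref (list '+ b c)))
   ...]
  [(list op args ...)
   ;; After deref, args are batch-refs, so unwrap each index before reindexing
   (define (reindex-ref ref) (vector-ref reindex (batch-ref-idx ref)))
   (batch-add! b2 (list* op (map reindex-ref args)))]
  [other
   (batch-add! b2 other)])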

Here I've used deref together with a neat Racket feature to make the pattern matching nicer. First, I construct a batch-ref and deref it before entering the match; this replaces all my back-references with batch-ref objects.

Then, in the pattern match, I match (app deref ...), a Racket app pattern. This matches a batch-ref, calls deref on it, and then matches on the output; this lets me turn the nested match calls into a flat call, which also means I only need to write the else case once. app patterns are a nice feature of Racket, kind of analogous to "views" in Scala or Python.

In the main pattern matching case, a, b, and c are now batchrefs. The next idea is just to construct the replacement expression out of these batchrefs:

(match (deref (batch-ref b1 idx))
  [(list '+ a (app deref (list '+ b c)))
   (define new-expr (list '+ (list '+ a b) c))
   ...]
  ...)

The new-expr is an expression tree where the leaves are references into the input batch. We can then recursively add every node in that expression, translating any batchrefs from old to new batch using the reindexing table:

(define (batch-add-transformed! b2 reindex expr)
  (match expr
    [(batch-ref b1 idx)
     (vector-ref reindex idx)]
    [(list op args ...)
     (batch-add!
      b2
      (list* op (map (curry batch-add-transformed! b2 reindex) args)))]
    ;; Literals introduced by the replacement expression itself
    [atom
     (batch-add! b2 atom)]))

Then the actual transformation looks like this:

(match (deref (batch-ref b1 idx))
  [(list '+ a (app deref (list '+ b c)))
   (define new-idx (batch-add-transformed! b2 reindex (list '+ (list '+ a b) c)))
   (vector-set! reindex idx new-idx)]
  ...)

You can wrap this whole pattern into a helper function, so that the final invocation looks like this:

(batch-replace b1
  (lambda (node)
    (match node
      [`(+ ,a ,(app deref `(+ ,b ,c))) `(+ (+ ,a ,b) ,c)]
      [other other])))

The batch-replace helper handles the allocation, the reindexing, constructing the new batch, and replacing all the indices. The node passed into the callback uses batch-ref for all the back-references, hence the deref. In Racket you can even write some custom syntax around this; we're still playing with what exactly that syntax should be and haven't yet settled on anything.
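To give a feel for what such a helper involves, here is a self-contained sketch. It is an assumption about the shape of batch-replace, not Herbie's implementation: it takes the input batch as a plain vector, keeps the output's hash table and node list in local variables, and returns the output batch together with the new index of the input's last node.

```racket
#lang racket

(struct batch-ref (batch idx))

;; Pull out one level of a node, wrapping back-references as batch-refs.
(define (deref bref)
  (match-define (batch-ref b idx) bref)
  (match (vector-ref b idx)
    [(list op args ...) (cons op (map (curry batch-ref b) args))]
    [other other]))

;; Sketch of batch-replace. f maps a dereferenced node to an expression
;; tree whose leaves may be batch-refs into b1.
(define (batch-replace b1 f)
  (define index (make-hash))   ; output node -> output index
  (define out '())             ; output nodes, newest first
  (define (add! node)
    (hash-ref! index node
               (lambda ()
                 (define idx (hash-count index))
                 (set! out (cons node out))
                 idx)))
  (define reindex (make-vector (vector-length b1) #f))
  ;; Add an expression tree, translating batch-refs via reindex.
  (define (add-transformed! expr)
    (match expr
      [(batch-ref _ idx) (vector-ref reindex idx)]
      [(list op args ...) (add! (cons op (map add-transformed! args)))]
      [atom (add! atom)]))
  (for ([idx (in-range (vector-length b1))])
    (vector-set! reindex idx
                 (add-transformed! (f (deref (batch-ref b1 idx))))))
  (values (list->vector (reverse out))
          (vector-ref reindex (- (vector-length b1) 1))))

;; The associativity rewrite applied to a + (b + c).
(define-values (rewritten root)
  (batch-replace
   (vector 'a 'b 'c '(+ 1 2) '(+ 0 3))
   (lambda (node)
     (match node
       [`(+ ,a ,(app deref `(+ ,b ,c))) `(+ (+ ,a ,b) ,c)]
       [other other]))))
rewritten
```

Because the callback only ever sees batch-refs, it has no raw indices to forget to translate; the one place that touches reindex is add-transformed!.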

Batches of multiple expressions

A batch can also store multiple expressions. In Herbie this is extremely important because all of our expressions are quite similar. In one key part of Herbie we already do this, and the batch is 25 times smaller than the expressions themselves.

A batch storing multiple expressions looks the same as a normal batch. But we now also want to keep track of where each initial expression is in the batch; we call that second vector the "rootvec". To go back from the batch to the expressions, you use the iterative algorithm above, and at the end read one expression for each entry of the rootvec. It's pretty fast.
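A sketch of the round trip, with made-up names throughout (mbatch, batch-add!, exprs->batch, batch->exprs) and a hash-table-plus-list mutable batch standing in for Herbie's:

```racket
#lang racket

;; Hypothetical mutable batch: hash table plus reversed node list.
(struct mbatch (index [nodes #:mutable] [size #:mutable]))
(define (make-mbatch) (mbatch (make-hash) '() 0))
(define (batch-add! b node)
  (hash-ref! (mbatch-index b) node
             (lambda ()
               (define idx (mbatch-size b))
               (set-mbatch-nodes! b (cons node (mbatch-nodes b)))
               (set-mbatch-size! b (+ idx 1))
               idx)))
(define (mbatch->vector b) (list->vector (reverse (mbatch-nodes b))))

;; Flatten a list of expressions into one batch plus a rootvec.
(define (exprs->batch exprs)
  (define b (make-mbatch))
  (define (add! expr)
    (match expr
      [(list op args ...) (batch-add! b (cons op (map add! args)))]
      [atom (batch-add! b atom)]))
  (define roots (for/vector ([expr (in-list exprs)]) (add! expr)))
  (values (mbatch->vector b) roots))

;; Go back: one forward pass over the batch, then one read per root.
(define (batch->exprs batch roots)
  (define exprs (make-vector (vector-length batch) #f))
  (for ([node (in-vector batch)] [idx (in-naturals)])
    (vector-set! exprs idx
                 (match node
                   [(list op refs ...)
                    (cons op (map (curry vector-ref exprs) refs))]
                   [atom atom])))
  (for/list ([root (in-vector roots)])
    (vector-ref exprs root)))

(define inputs '((* x (tan x)) (+ (tan x) 1) (tan x)))
(define-values (batch roots) (exprs->batch inputs))
batch
roots
(batch->exprs batch roots)
```

The three input trees have ten nodes between them, while the shared batch has five, and the third root simply points at a node that the first expression already created.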

As soon as you have multiple expressions, it becomes even easier to have nodes not referenced from one of the "root nodes". We call these "zombie" nodes and you have to "cull" them every now and then. If you think about it, the batch is kind of like an arena allocator for expressions, and zombie culling is just manual garbage collection. The same top-down-then-bottom-up pattern described earlier works. It's basically a mark-copy collector.

There's a bit of a trade-off here: if you have too many "zombie" nodes, you can waste time analyzing or rewriting those nodes, but if you "cull" the zombie nodes too often that also takes a lot of time. This is the normal garbage collection frequency tradeoff. So sometimes you make judgment calls about whether or not to delete those zombie nodes.

For example, Herbie's series expander rewrites pow(x, 1/2) into sqrt(x), because the series expander thinks about real semantics and it has custom sqrt code that's faster than its generic pow code. The rewrite itself is done just like above, using batch-replace. But it can leave a zombie 1/2 node in the batch, if the pow was the only expression that used 1/2. We could cull the zombies afterwards—but series-expanding a constant is very cheap, so it's better to just leave the zombie node around.

A type perspective

One big problem working with batches is that you end up with lots of vectors. A batch of multiple expressions is already two vectors (the batch and the rootvec). A rewrite involves a reindexing vector and two different batches. Then you map over the batches to produce new vectors. Pretty quickly you have a dozen vectors and it's hard to remember which vector is parallel to which other vector, and which indices need to be transformed how. It's a mess.

I find it helpful to think about this from a type perspective. Imagine a vector not as a type vector<T> but as a type vector<α, T> where α represents the "index space". You can think of it as the vector length, but I prefer thinking of it as a phantom type, totally meaningless on its own. In practice, we represent the index space as a reference to a batch. Each batch is its own index space, and then if you compute some value for each node in a batch, that has the same index space as the batch.

If you have two vectors with the same length, like in the depth code above, then both vectors have the same index space:

batch : vector<α, Node>
depths : vector<α, Nat>

When two vectors have the same index space you can traverse them in parallel.

Alternatively, if you have two vectors of different lengths, you think of them as having different index spaces:

input-batch : vector<α, Node>
output-batch : vector<β, Node>

When you have a reindexing vector, it maps from one indexing space to another:

reindex : vector<α, β>

A batch-ref is a batch plus an index, which we think of as a member of the index space. So:

batch-ref<α> := vector<α, Node> × α
deref : vector<α, Node> × α → Node

That said, it's sometimes better to identify batch-ref<α> with α itself, that is, to think about a batch-ref as a kind of safe index type. I already mentioned that sometimes you want to represent the type α at runtime as a reference to a batch, so then I think of a batch-ref<α> as storing the batch because it's storing its type parameter (represented at runtime with a batch pointer) and storing the index because all it's actually doing is wrapping the index.
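In Racket the index space can only be checked at runtime. As a rough sketch of what that might look like, the hypothetical ivec struct below pairs a side vector with the batch whose index space it shares, and ivec-ref (also hypothetical) refuses a batch-ref that points into a different batch:

```racket
#lang racket

(struct batch-ref (batch idx))

;; Hypothetical "indexed vector": data holds one value per node of the
;; batch space, so space plays the role of the type parameter α.
(struct ivec (space data))

;; Index with a batch-ref, checking that it lives in the right index space.
(define (ivec-ref v ref)
  (unless (eq? (batch-ref-batch ref) (ivec-space v))
    (raise-argument-error 'ivec-ref
                          "a batch-ref into this vector's index space"
                          ref))
  (vector-ref (ivec-data v) (batch-ref-idx ref)))

(define b1 (vector 'x '(tan 0) '(* 0 1)))
(define b2 (vector 'x '(tan 0) '(* 0 1)))   ; equal contents, different space
(define depths (ivec b1 (vector 0 1 2)))    ; bottom-up depths of b1's nodes

(ivec-ref depths (batch-ref b1 2))          ; fine: same index space
;; (ivec-ref depths (batch-ref b2 2))       ; would raise a contract error
```

The check uses eq? rather than equal? on purpose: two batches with identical contents are still different index spaces.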

Keeping the types in mind reminds you to transform vectors appropriately. So for example, if you have a vector of expressions, you can transform them into a single batch using a method:

exprs2batch : vector<α, Expr> → ∃ β. vector<α, β> × vector<β, Node>

Note that the β type parameter is existentially bound here, that is, you do not a priori know how big the output batch will be, but you can still safely access its elements because you have that rootvec to map from α to β.

Batches for multiple types

Batches make a lot of sense for expressions, but actually there are lots of different recursive data structures in Herbie: expressions, derivations, types, and so on. We probably want to use batches for all of them (though expressions are definitely most common and most performance-sensitive).

I find the types particularly illuminating in this case. The place to start is the concept of the μ type:

μ<F> := F<μ<F>>

This is not a helpful definition, so instead let's look at an example. An expression is either a literal or an operator applied to some arguments:

Expr := Literal | Apply(Op, List<Expr>)

The "recursive shape" of this type can be described by this non-recursive, parameterized type: [1]

ExprF<T> := Literal | Apply(Op, List<T>)

The μ type then goes from this "functor" to the recursive structure, recovering Expr = μ<ExprF>. So how does this connect to batches?

Well, it answers, for one, what a "node" is, which I explained only by example above. A node is the algebra of an expression, parameterized by indices (backreferences). In other words, by vector<α, Node>, we mean:

vector<α, ExprF<α>>

The type ExprF<α> means "expression, but with all the backreferences replaced by indices into α". So exprs2batch can be generalized into a function flatten like this:

flatten : vector<α, μ<F>> → ∃ β. vector<α, β> × vector<β, F<β>>

This is the kind of thing people make fun of Haskell for: one-line definitions that take three pages of explanation. Anyway, I write batch<β, F> for vector<β, F<β>>.

In Racket, we don't have all these types but there's a more basic option that fits dynamic typing. Instead of μ types and all that jazz, we have a "recursor":

(define (expr-recursor expr f)
  (match expr
    [(list op args ...)
     (list* op (map f args))]
    [other other]))

This is like a recursive walk over expressions, except the recursive call is replaced with f. In type theory terms it is the fmap of the functor.

Then a batch stores its recursor, and deref works like this:

(define (deref bref)
  (match-define (batch-ref b idx) bref)
  (match-define (batch nodes recursor) b)
  ;; Use the recursor stored in the batch, not a hard-coded one
  (recursor (vector-ref nodes idx)
            (lambda (i) (batch-ref b i))))

This is now generic over recursors, and other helper functions like batch-replace and zombie culling can also be made generic in this way.
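As an illustration of that genericity, the sketch below flattens and unflattens two different structures with the same code, varying only the recursor. Everything beyond expr-recursor, batch-ref, and deref is made up for the example: the fn-type struct, type-recursor, flatten-into-batch, and batch->value are hypothetical, and the batch struct is assumed to hold a vector of nodes plus its recursor.

```racket
#lang racket

;; Assumed batch representation: a vector of nodes plus the recursor.
(struct batch (nodes recursor))
(struct batch-ref (batch idx))

(define (expr-recursor expr f)
  (match expr
    [(list op args ...) (cons op (map f args))]
    [other other]))

;; A second, made-up recursive structure: function types as structs.
(struct fn-type (arg result) #:transparent)
(define (type-recursor type f)
  (match type
    [(fn-type arg result) (fn-type (f arg) (f result))]
    [other other]))

;; Generic flattening: the recursor replaces children with their indices.
(define (flatten-into-batch value recursor)
  (define index (make-hash))
  (define nodes '())
  (define (add! v)
    (define node (recursor v add!))
    (hash-ref! index node
               (lambda ()
                 (define idx (hash-count index))
                 (set! nodes (cons node nodes))
                 idx)))
  (add! value)
  (batch (list->vector (reverse nodes)) recursor))

;; Generic deref: wrap back-references using the batch's own recursor.
(define (deref bref)
  (match-define (batch-ref b idx) bref)
  (match-define (batch nodes recursor) b)
  (recursor (vector-ref nodes idx)
            (lambda (i) (batch-ref b i))))

;; Generic unflattening: the same forward loop works for any recursor.
(define (batch->value b)
  (match-define (batch nodes recursor) b)
  (define outs (make-vector (vector-length nodes) #f))
  (for ([node (in-vector nodes)] [idx (in-naturals)])
    (vector-set! outs idx (recursor node (curry vector-ref outs))))
  (vector-ref outs (- (vector-length nodes) 1)))

(define ty (fn-type 'real (fn-type 'real 'bool)))
(define type-batch (flatten-into-batch ty type-recursor))
(define expr-batch (flatten-into-batch '(* x (tan x)) expr-recursor))
(batch-nodes type-batch)
(batch-nodes expr-batch)
```

The struct is declared #:transparent so that equal nodes hash alike, which the deduplicating hash table relies on.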

Moreover, if you want to write a function that transforms one recursive structure to another (for example, type checking an expression) you just need to take care to create input and output batches with different recursors. In a typed language you wouldn't need to store the recursor—that would be done with and enforced by types—but in Racket this is still pretty ergonomic.

(Specifically, in Racket, helper functions like batch-replace that output a batch have a variant, batch-replace!, that outputs into a caller-supplied mutable batch. You can construct that mutable batch with whatever recursor you want. In Racket, having mutable and immutable versions of the same data structure and the same methods is fairly common.)

It's also possible to have data structures that mix multiple types of recursive data structures. For example, if you have expressions and types, and the expressions contain type annotations, you really have something like this:

TypeF<T> := BaseT | FnT(T, T)
Type := μ<TypeF>

ExprF<Type><T> := Var(Int) | App(T, T) | Lam(Int, Type, T)
Expr := μ<ExprF<Type>>

Note that the type algebra is standard but the expression algebra has a parameter T for subexpressions and also a parameter Type for types. (Which I've written curried.) If you want to store a bunch of expressions (and their associated types) in batches, it'd look like this:

exprs : batch<α, ExprF<β>> = vector<α, ExprF<β><α>>
types : batch<β, TypeF> = vector<β, TypeF<β>>

In other words, the batch of expressions exprs stores expressions, except with subexpressions replaced by backreferences and with types replaced by references into the batch of types, which itself stores backreferences to itself. While this is obviously kind of confusing, you can keep it somewhat straight with types or recursors.

Conclusion

In Herbie we are making a big push to use batches everywhere we can. I think it will make Herbie faster and use less memory.

Of course, Herbie already has a lot of code (about 15k lines!) and the conversion will be slow, and right now we lack experience working with batches. The examples above show how to convert basic functions over expressions into ones that work on batches, and the abstractions like batch-refs and recursors show that it's possible to operate on batches at a reasonably high level.

As my colleagues and I get more experience with batches, we hope to discover even more such abstractions and ultimately make batches easy to work with.

Footnotes:

[1] By the way, in an earlier draft of this blog post I made a total mess of the type theory but James set me straight. Thanks, James!