By Pavel Panchekha. Share under CC-BY-SA.

Scheduling Batch Computations

Last time, I talked about batches, my term for an arena-allocated, flattened form of an AST, which I think is going to be a key data structure in Herbie. That post covered the basics and sketched out a type theory; in this post I want to dive deeper into computing with batches.

Batch Refresher & Index Safety

Here's just a quick review from last time. A batch stores a recursive data structure like an AST in a flattened form. The batch itself is a vector of "nodes", where a node is an AST node but with subexpressions replaced by integers that represent back-references into the same batch. In a low-level language like Rust, batches can lead to big speedups due to cache effects and reduced memory usage, but even in a high-level language like Racket, which Herbie is written in, batches can deduplicate common subexpressions, which are very common, especially in Herbie.
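
To make the flattened form concrete, here is a minimal sketch in plain Racket. It assumes a hypothetical node format (not Herbie's): a leaf is a symbol or number, and an operator node is a list (op arg-index ...) whose arguments are indices of earlier nodes in the same vector. The expression (+ (* x x) (* x x)) then needs only three nodes, because the repeated (* x x) is stored once; batch->expr is a made-up helper that rebuilds the tree from any index.

```racket
#lang racket

;; Hypothetical node format (not Herbie's): a leaf is a symbol or number,
;; and an operator node is a list (op arg-index ...).
;; This batch encodes (+ (* x x) (* x x)); the repeated (* x x) is stored once.
(define example-batch
  (vector 'x          ; 0: the variable x
          '(* 0 0)    ; 1: (* x x), both arguments point back at node 0
          '(+ 1 1)))  ; 2: the root, both arguments point back at node 1

;; Rebuild the tree-shaped expression rooted at index i by following
;; back-references. Shared nodes are expanded once per use.
(define (batch->expr batch i)
  (match (vector-ref batch i)
    [(list op args ...)
     (cons op (for/list ([a (in-list args)])
                (batch->expr batch a)))]
    [leaf leaf]))

(batch->expr example-batch 2) ; => '(+ (* x x) (* x x))
```

Note that the tree has seven nodes while the batch has three; the savings grow with every level of sharing.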

The last post sketched out a basic type system for batches. Every batch gets a unique type-level variable called an "index space", which we write α, β, and so on and which we also type-pun as an integer. Then a batch is:

batch<α, F> = vector<F<α>>

Here F is a "node" type, parameterized by the type used for subexpressions. An actual expression is F<F<F<...>>>, which is typically written μ<F>.

The point of the α parameter is to provide "index safety": to keep track of which integer index points to which batch. The same index means different expressions in different batches, and by thinking of indices not as plain integers but as these "index space" newtypes, we can enforce a type discipline that prevents mistakes like using an index in the wrong batch. The actual implementation of this type parameter sort of depends on the language. It could be existentially bound, or a dependent type, or something else.

Another thing we often need to do is to compute some generic value for each expression in a batch. An example might be the cost of an expression, or the output from evaluating it on some input, or its series expansion, or something like that. We call those data:

data<α, T> = vector<T>

Note that again we keep track of the index space at the type level. This keeps track of what batch the values were computed on, so we don't get confused once multiple batches are in play.

Finally, a lot of functions take in one batch and output another (think of a "rewrite" operation), which means they not only need to take in a batch and output a batch, but in the process they also need to keep track of which input expressions turn into which output expressions. We call this a "mapping"[1]:

mapping<α, β> = vector<β>

Note that the mapping goes from one index space to another; the type discipline basically requires you to use mappings appropriately and thus guarantees index safety.

In total, I call these three types—batches, data, and mappings—different "columns", written col<α, X>, whose index space is α. You can think of this as subtyping or interfaces or whatever.

Besides enforcing index safety, the type discipline also provides this guarantee: if two columns have the same index space, then their vectors have the same underlying length, which means you can traverse them in parallel (because an index into one is a valid index into the other).

I've taken to calling a group of columns with the same index space a "table"; in this analogy a batch is like a "database" of expressions, and the type discipline provides a "schema". Note that unlike a SQL database, the "table" is ephemeral: it is a human-level construct to make sense of iteration, not a data structure in memory, while columns are actual vectors in memory.

Computing on columns

A computation on columns typically takes the form of iterating over all rows in a table and using them to fill in another column. The interesting case is a recursive traversal on an expression. For example, consider Herbie's cost function on expressions: the cost of an expression is the cost of its root node plus the cost of all its subexpressions, except that for if statements only the more expensive of the then and else branches counts. The actual cost numbers are customizable and live in a Herbie data structure called a platform. A sketch looks like this:

(define (cost platform expr)
  (let loop ([expr expr])
    (match expr
      [(? literal?) (platform-literal-cost platform)]
      [(? symbol?) (platform-literal-cost platform)]
      ;; if: the condition plus the more expensive branch
      [(list 'if c t f)
       (+ (platform-if-cost platform)
          (loop c)
          (max (loop t)
               (loop f)))]
      ;; approx: only the implementation e is costed
      [(approx s e)
       (loop e)]
      ;; any other operator: its own cost plus its arguments' costs
      [(list op args ...)
       (apply + (platform-op-cost platform op) (map loop args))])))

If you're not familiar with Racket, the let loop construct lets you define an inline recursive function and call it on some arguments.

To run this on a batch, we first construct an output data column:

(define (batch-cost platform batch) ; batch : batch<α, Impl>
  (define out (make-data batch)) ; out : data<α, float>
  ...)

Here the make-data function takes in a batch, which serves as its index space and also determines the length of the column. I've added type annotations (which in Racket don't do anything except keep us sane); Impl is the relevant one of Herbie's two expression types.

Now we need to loop over the batch, computing the cost of each expression:

(define (batch-cost platform batch) ; batch : batch<α, Impl>
  ...
  (for ([i (in-naturals)] ; i : α
        [node (in-col batch)]) ; node : Impl<α>
    (define my-cost
     (match node
       [(? literal?) (platform-literal-cost platform)]
       [(? symbol?) (platform-literal-cost platform)]
       [(list 'if c t f) ; c, t, f : α
        (+ (platform-if-cost platform)
           (col-ref out c)
           (max (col-ref out t)
                (col-ref out f)))]
       [(approx s e) ; e : α
        (col-ref out e)]
       [(list op args ...) ; args : list<α>
        (apply + (platform-op-cost platform op)
               (map (curry col-ref out) args))]))
    (col-set! out i my-cost))
  out)

This is almost the same thing as above, except with every call to loop replaced by col-ref out. Here, col-ref and col-set! are just vector-ref and vector-set!, but with type-checked index spaces[2]:

col-ref : col<α, T> -> α -> T
col-set! : col<α, T> -> α -> T -> { col<α, T> }

Following these types enforces index safety.
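
Racket won't check these types for us, but a dynamic approximation shows what they rule out. In this hypothetical sketch (none of these are Herbie's definitions), a column carries a tag standing in for its index space α, an index carries the tag of the batch it came from, and col-ref and col-set! raise an error when the two disagree. The make-data here is a stand-in that mirrors the description above: it copies the batch's index space and length.

```racket
#lang racket

;; Hypothetical dynamic stand-in for index spaces: a column remembers its
;; space, and an index remembers which space it came from.
(struct col (space vec))
(struct idx (space n))

;; Like make-data: a fresh column sharing the batch's index space and length.
(define (make-data batch)
  (col (col-space batch) (make-vector (vector-length (col-vec batch)) #f)))

(define (check-space! who c i)
  (unless (eq? (col-space c) (idx-space i))
    (error who "index from space ~a used on a column of space ~a"
           (idx-space i) (col-space c))))

;; col-ref and col-set! are vector-ref and vector-set! plus the space check.
(define (col-ref c i)
  (check-space! 'col-ref c i)
  (vector-ref (col-vec c) (idx-n i)))

(define (col-set! c i v)
  (check-space! 'col-set! c i)
  (vector-set! (col-vec c) (idx-n i) v))

;; Each batch gets its own fresh space (an uninterned symbol).
(define alpha (gensym 'alpha))
(define beta (gensym 'beta))
(define batch-a (col alpha (vector 'x '(* 0 0))))
(define batch-b (col beta (vector 'y)))
(define data-a (make-data batch-a))

(col-set! data-a (idx alpha 1) 3.0) ; OK: same index space
(col-ref data-a (idx alpha 1))      ; => 3.0
;; (col-ref batch-b (idx alpha 0))  ; raises: index from the wrong batch
```

A static type system catches the same mistake before the code runs, and after erasure pays nothing for it; this version pays a tag comparison on every access.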

Dependency safety

But some errors are still possible. For example, if I forget the col-set! call, I'll cause two issues. First, I won't save the costs anywhere. But second, I'll also read garbage data for arguments. Ruling out that second issue is what I call dependency safety: it's unsafe to read a value from a column before that value is properly computed.

Dependency safety is not a problem for immutable columns, because presumably all their entries are already computed. But computing on batches necessarily requires mutating columns, and mutation introduces data dependencies that we have to track.

What matters for safety is ordering. This cost procedure is a bottom-up recursion (the costs of leaves are computed first) which is why we iterated through the nodes in order. And we know that nodes only contain backreferences. So we actually have something like these pre/post-conditions in the loop:

(define (batch-cost platform batch) ; batch : batch<α, Impl>
  (define out (make-data batch))
  ; out : data<α.none, float>
  (for ([i (in-naturals)] ; i : α
        [node (in-col batch)]) ; node : Impl<α.lt(i)>
    ; out : data<α.lt(i), float>
    (define my-cost ...)
    (col-set! out i my-cost)
    ; out : data<α.lte(i), float>
    )
  ; out : data<α, float>
  out)

In other words, no entries of out are valid before we enter the loop. Then, at the start of an iteration, all entries before i are valid. Crucially, the node contains only indices before i, so those are valid indices into out. By calling col-set! we set the i-th entry, making it valid, which ensures that at the end of the iteration, all entries before and including i are valid; that is exactly the invariant needed at the start of the next iteration. Thus, at the end of the loop, all entries are valid.

Note what logic we need to actually check safety here. We need indices, with their different types. We also need the type α.lt(i) and similar. Those types make the col-ref and col-set! methods safer:

col-ref : col<α.lt(i), T> -> α.lt(i) -> T
col-set! : col<α.lt(i), T> -> α.eq(i) -> T -> { col<α.lte(i), T> }

Note that col-set! changes the index space in the type from α.lt(i) to α.lte(i), indicating that, by setting the i-th element, that element is now presumed valid. This type discipline enforces the dependency constraint (that we only read from valid entries), requires us to write to the vector, and also prevents weirder issues like writing twice to the same entry.
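
These checks can also be approximated at runtime. In this hypothetical sketch, a column keeps a validity bitmap next to its values; col-ref refuses to read a cell that hasn't been written, and col-set! refuses to write a cell twice, the two failures that the α.lt(i) and α.lte(i) types rule out statically. Nodes are leaves or lists (op arg-index ...) with backward indices, and the cost model (every leaf costs 1, every operator costs 2) is made up to stand in for Herbie's platform. The #:set? flag is a hypothetical switch for simulating the forgotten col-set!.

```racket
#lang racket

;; Hypothetical column that also tracks which cells have been computed.
(struct dcol (vec valid))

(define (make-dcol n)
  (dcol (make-vector n #f) (make-vector n #f)))

;; Reading an uncomputed cell is a dependency error.
(define (col-ref c i)
  (unless (vector-ref (dcol-valid c) i)
    (error 'col-ref "read cell ~a before it was computed" i))
  (vector-ref (dcol-vec c) i))

;; Writing a cell twice is also rejected.
(define (col-set! c i v)
  (when (vector-ref (dcol-valid c) i)
    (error 'col-set! "cell ~a written twice" i))
  (vector-set! (dcol-vec c) i v)
  (vector-set! (dcol-valid c) i #t))

;; Made-up cost model: every leaf costs 1, every operator (including if) costs 2.
(define leaf-cost 1)
(define op-cost 2)

;; Bottom-up cost over a batch whose argument indices all point backwards.
;; Passing #:set? #f "forgets" the col-set! call to trigger the error.
(define (batch-cost batch #:set? [set? #t])
  (define out (make-dcol (vector-length batch)))
  (for ([i (in-naturals)]
        [node (in-vector batch)])
    (define my-cost
      (match node
        [(list 'if c t f)
         (+ op-cost (col-ref out c) (max (col-ref out t) (col-ref out f)))]
        [(list _ args ...)
         (apply + op-cost (map (λ (a) (col-ref out a)) args))]
        [_ leaf-cost]))
    (when set? (col-set! out i my-cost)))
  (dcol-vec out))

(batch-cost (vector 'x '(* 0 0) '(+ 1 1))) ; => '#(1 4 10)
```

With #:set? #f, the first operator node fails as soon as it reads its argument: the garbage-read bug described above, caught at runtime rather than by the type checker.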

Though note that as a type system, it's pretty weird! We have type state and functions that mutate the type of their argument. To make that sane you'd want col-set! to return the updated vector in a functional style, which is fine but then you'll want some kind of affine type system to avoid accidental copies. And, of course, you'd probably want a solver for ordering constraints (if i : α.lt(j) and j : α.lt(k), then i : α.lt(k); I think it'd be tedious to have to write conversion functions, which would act basically as a proof) and type simplification (α.lt(i).lt(j) should simplify to something? Maybe α.lt(min(i, j))?). I think you could do it statically, especially in a language like Rust that provides affine types.

Just as one example of the weirdness, what type does a batch have? Above I wrote batch<α, F> is col<α, F<α>>, which is true enough. But if we want to be particular about dependencies, shouldn't it be

batch<α, F> = col<i : α, F<α.lt(i)>>

In other words, we need a new kind of dependency just to type col terms properly. The second argument to col would have to be a function:

col : (α : Type) -> (α -> Type) -> Type

This would be a huge pain to work with in the real world, and would need a hugely sophisticated type system. But, you know, possible in theory. (In practice, I would skip this and just have a special batch-ref function that does the right thing.) I think, done correctly, you should get a kind of erasability: first, you should be able to erase α.lt(i) to just α (and still have index safety), and second, you should be able to erase α to int and col<α, T> to vector<T> to actually execute the code.
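
Short of that type system, the invariant behind batch<α, F> = col<i : α, F<α.lt(i)>> can at least be checked at runtime. A minimal sketch, assuming a hypothetical node format where a node is either a leaf (symbol or number) or a list (op arg-index ...); batch-well-formed? is a made-up name:

```racket
#lang racket

;; Runtime check of the batch invariant: every argument of node i is a valid
;; index strictly less than i (the dynamic counterpart of F<α.lt(i)>).
;; Hypothetical node format: a leaf, or a list (op arg-index ...).
(define (batch-well-formed? batch)
  (for/and ([node (in-vector batch)]
            [i (in-naturals)])
    (match node
      [(list _ args ...)
       (for/and ([a (in-list args)])
         (and (exact-nonnegative-integer? a) (< a i)))]
      [_ #t])))

(batch-well-formed? (vector 'x '(* 0 0) '(+ 1 1))) ; => #t
(batch-well-formed? (vector 'x '(+ 1 0)))          ; => #f: node 1 points at itself
```

This is the property that makes an in-order, bottom-up traversal safe: any index a node mentions has already been visited.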

More complex dependency safety examples

If you stopped and actually read Herbie's platform cost procedure, you'd have noticed that the sketch I gave above is too simple. To compute the cost of an Impl in Herbie, we need to know its type, which in Herbie is called a repr, so the cost procedure is a top-down traversal, computing the repr, followed by the bottom-up cost computation:

(define (cost platform repr expr)
  (let loop ([repr repr] [expr expr])
    (match expr
      ...
      [(list op args ...)
       (define itypes (impl-info op 'itype))
       (define arg-costs
         (for/list ([arg (in-list args)]
                    [itype (in-list itypes)])
           (loop itype arg)))
       (apply + (platform-op-cost platform op)
              arg-costs)])))

I'm showing one branch of the match; you can fill in the rest. The itypes variable stores the types of each argument to op, so we traverse itypes and args together and make the appropriate recursive call.

To implement this in batches, we'd need two traversals:

(define (batch-cost platform batch)
  (define reprs (make-data batch))
  (for ([i (in-range (- (col-length batch) 1) -1 -1)] ; last index down to 0
        [node (in-col-rev batch)])
    (define my-repr (col-ref reprs i))
    (match node
      ...
      [(list op args ...)
       (define itypes (impl-info op 'itype))
       (for ([arg (in-list args)]
             [itype (in-list itypes)])
         (col-set! reprs arg itype))]))

  (define out (make-data batch))
  (for ([i (in-naturals)]
        [node (in-col batch)])
    (define my-cost ...)
    (col-set! out i my-cost))
  out)

In this top-down traversal, reprs stores the type of each node, which is the repr argument that node's loop call would receive. So in this loop reprs has type col<α.gte(i), repr> (we've already "called" the current node) and col-set! has an alternative type:

col-set! : col<α.gte(i), T> -> α.lt(i) -> T -> { col<α.gte(i), T> }

But hold on! In this case the type of the argument doesn't change. And that's correct: there's no guarantee that the node right before the current one is pointed to by the current one. It might be pointed to by some unrelated expression. Or, it might not be pointed to by any node at all; it could be a "zombie node" as discussed in the last post. So we need to initialize every node with some kind of default repr, which means we need an init value. Moreover, it's possible for two different iterations to col-set! the same index, for example if a subexpression appears twice in the tree. Then it's possible for it to be assigned two different reprs, so we'll need some kind of join function.

But then what type should the column have? Since every value is initialized, perhaps just col<α, T>, indicating that every index is valid? But then it seems like we'll lose all dependency safety, and would be allowed to read any value at all. But it can't just be col<α.gte(i), T>, because then we lose the inductive property.

I think the best bet might be adding a new type of column that stores its init value and join method, and then adding a col-default! operation on these columns, with type:

col-default! : col<α.gt(i), T> -> α.eq(i) -> { col<α.gte(i), T> }

The behavior of the function is to join the current cell with the init value (assuming the init is a unit of join), and the user calls this at the start of every iteration to make sure the cell is valid.
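
A minimal sketch of such a column, assuming a hypothetical jcol struct that stores its cells and a join function, with every cell pre-filled with init (so the col-default! join with init would be a no-op and is omitted). Reprs are plain symbols, the itypes table is made up, init is the empty set, and join is set-union, whose unit is the empty set; Herbie's real reprs, itypes, and join would be different. Iterating from the last index downward works because every parent of node i has a larger index, so node i has received all of its contributions by the time the loop reaches it.

```racket
#lang racket

;; Hypothetical column with an init value and a join function. Every cell
;; starts at init, and later writes are joined into the cell.
(struct jcol (vec join))

(define (make-jcol n init join)
  (jcol (make-vector n init) join))

(define (col-join! c i v)
  (define vec (jcol-vec c))
  (vector-set! vec i ((jcol-join c) (vector-ref vec i) v)))

;; Made-up argument types (itypes) for two operators.
(define itypes
  (hash '+ '(binary64 binary64)
        'cast '(binary32)))

;; Top-down pass from the root (the last node) down to index 0. Each node
;; pushes its argument reprs into its children. Init is the empty set and
;; join is set-union, so a shared child collects every repr it is used at,
;; and a zombie node (no parents) keeps the init value.
(define (batch-reprs batch root-repr)
  (define n (vector-length batch))
  (define reprs (make-jcol n (set) set-union))
  (col-join! reprs (- n 1) (set root-repr))
  (for ([i (in-range (- n 1) -1 -1)])
    (match (vector-ref batch i)
      [(list op args ...)
       (for ([arg (in-list args)]
             [itype (in-list (hash-ref itypes op))])
         (col-join! reprs arg (set itype)))]
      [_ (void)]))
  (jcol-vec reprs))

;; Node 0 (x) is used both as a binary32 (under cast) and as a binary64.
(batch-reprs (vector 'x '(cast 0) '(+ 1 0)) 'binary64)
```

Here x ends up with two reprs, which is exactly the shared-subexpression case that forces a join; a real implementation would have to decide what such a join means for Herbie's reprs.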

Mutable and immutable

What all this is pointing at is that mutable columns are more complex than immutable ones.

An immutable column has a clear and simple type like col<α, T>, with no constraints on α. A mutable column, on the other hand, has some weird constraints on α, and has to do some kind of loop tracking how these constraints are updated. At the end of the loop, every cell is filled in, and the column becomes immutable.

So far so good—for data columns (and mappings). These have a fixed size. But batches are more complex. Batches are updated by adding expressions to them, which are then deduplicated. This means that a mutable batch, one that we're not done adding things to yet, has unknown total size.

Our constraint type system can represent this, actually. A mutable batch is a column, with only indices up to its current length valid:

mutable-batch<α, F> ◁ ∃ u : α, col<i : α.lt(u), F<α.lt(i)>>

I used an existentially-bound u here, but you could be explicit that this is the length of the batch if your type system allowed it. (Our hypothetical type system is complex enough!) Then adding to a batch would have this type:

batch-add! :
  (∃ u : α, col<i : α.lt(u), F<α.lt(i)>>) -> 
  F<α.lt(u)> ->
  { ∃ v : α.gte(u), col<i : α.lt(v), F<α.lt(i)>> } *
  α.lt(v)

Note that the mutable batch initially has length u, and the node we add has to have all its backreferences less than u to make sure they are valid. Then the output has a new length v, which is at least as big as u (not strictly bigger, because we could have added a duplicate node and not changed the length), and the mutable batch now has this new length. The index of the added node is less than the new length.
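
To make batch-add! concrete, here is a minimal sketch of a mutable batch in plain Racket; mbatch, make-mbatch, and this batch-add! are hypothetical, not Herbie's code. Nodes are leaves or lists (op arg-index ...). The function checks the precondition that the new node's backreferences are below the current length u, deduplicates through a hash, and returns the node's index, so the length either stays at u or grows to u + 1.

```racket
#lang racket

;; Hypothetical mutable, deduplicating batch: a vector with spare capacity,
;; the number of valid entries (the length u from the text), and a hash from
;; node to index. Nodes are leaves or lists (op arg-index ...).
(struct mbatch ([vec #:mutable] [len #:mutable] index))

(define (make-mbatch)
  (mbatch (make-vector 4 #f) 0 (make-hash)))

;; Add a node whose backreferences are all below the current length u.
;; Returns the node's index j; afterwards the length v is u (duplicate) or u + 1.
(define (batch-add! b node)
  (define u (mbatch-len b))
  (match node
    [(list _ args ...)
     (for ([a (in-list args)])
       (unless (and (exact-nonnegative-integer? a) (< a u))
         (error 'batch-add! "backreference ~a is not below length ~a" a u)))]
    [_ (void)])
  (hash-ref! (mbatch-index b) node
             (λ ()
               ;; New node: grow the storage if it is full, then append.
               (when (= u (vector-length (mbatch-vec b)))
                 (let ([bigger (make-vector (* 2 u) #f)])
                   (vector-copy! bigger 0 (mbatch-vec b))
                   (set-mbatch-vec! b bigger)))
               (vector-set! (mbatch-vec b) u node)
               (set-mbatch-len! b (+ u 1))
               u)))

(define b (make-mbatch))
(define x (batch-add! b 'x))             ; 0
(define xx (batch-add! b (list '* x x))) ; 1
(batch-add! b (list '* x x))             ; => 1 again, a duplicate
(mbatch-len b)                           ; => 2, unchanged by the duplicate
```

The returned index is always below the new length, matching the α.lt(v) in the signature, while the precondition check plays the role of F<α.lt(u)>.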

This type machinery might seem too abstract to be useful, but sadly it is necessary. For example, Herbie has a pass called "reduction" which is called as part of Taylor expansion. Here's how it works:

  • It checks for patterns like log(exp(x)) and simplifies them to x. This is done via a bottom-up traversal.
  • But for arithmetic, it calls a gather-additive-terms function, which turns sums into canonical forms. (It does this to find cancellation opportunities.) There's a similar gather-multiplicative-terms function. Both are bottom-up traversals.
  • (reduce e) first calls reduce on all subexpressions of e, then passes the resulting expression to gather-additive-terms, then uses that to compute the reduction of e.

If we were to turn this into a batch-based procedure, we would have:

  • The input batch, type batch<α, F>
  • The output batch, type mutable-batch<β, F>
  • A mapping between them, type mapping<α, β>
  • A column of gather-additive-terms outputs (and same for gather-multiplicative-terms). Since these are called after recursive reduction, this column has index space β.[3] Canonical forms of additive nodes contain backreferences, so this column has type data<i : β, ATerms<β.lt(i)>>.

The type checking is complex:

  • Suppose we're currently on index i in the input batch, and the output batch has length u.
  • So when reduce checks for patterns, it's looking at an F<α.lt(i)>.
  • The mapping maps the backreferences α.lt(i) to, uhh, β.lt(u), I guess? Hold on tight, it gets worse.
  • This gives us an F<β.lt(u)>, which we pass to gather-additive-terms. Note that the gather-additive-terms column is, presumably, valid up to u, so it can look up the recursive calls it needs.
  • That returns a node,[4] which we add to the output batch, which now has length v, getting a new index j : β.lt(v).
  • Now we need to update the mapping. Before the update, it has type mapping<α.lt(i), β.lt(u)>. After, it has type mapping<α.lte(i), β.lt(v)>. Note that we need to update both parameters. The second parameter can be updated by simple subtyping, and then the first parameter is updated by col-set!.
  • Now the output batch and the mapping have been updated, but we don't yet have updates for gather-additive-terms for all the new nodes in the output batch. So we need to fill in that column from u to v.

Hope you followed that. It is deep into the thick of type-based programming! But you can also see the work the types are doing. For example, that last step, of updating gather-additive-terms, is not obvious. It seems like you probably only need to update it for the one new node j, right? Not so, because a later addition might for some reason deduplicate into a node between u and j.
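
A drastically simplified version of this bookkeeping fits in a page of plain Racket. It is a hypothetical toy, not Herbie's reduction pass: the only rewrites are (log (exp e)) to e and (neg e) to (* -1 e), hashes stand in for growable columns, and a per-node tree size on the output batch stands in for the gather-additive-terms column. What it shares with the real thing is the shape of each iteration: translate the node's backreferences through the mapping, add zero or more nodes to the output batch, record the mapping entry, and then fill the auxiliary column over the whole range from the old length u to the new length v.

```racket
#lang racket

;; Hypothetical output batch: hashes stand in for growable vectors.
;; `nodes` maps index -> node, `index` maps node -> index (deduplication).
(struct obatch (nodes index [len #:mutable]))

(define (make-obatch) (obatch (make-hash) (make-hash) 0))

(define (add! b node)
  (hash-ref! (obatch-index b) node
             (λ ()
               (define j (obatch-len b))
               (hash-set! (obatch-nodes b) j node)
               (set-obatch-len! b (+ j 1))
               j)))

;; Toy rewrite pass (not Herbie's reduction): (log (exp e)) => e and
;; (neg e) => (* -1 e). Nodes are leaves or lists (op arg-index ...).
;; Returns the output batch, the mapping (input index -> output index), and an
;; auxiliary column giving each output node's tree size.
(define (rewrite batch)
  (define n (vector-length batch))
  (define out (make-obatch))
  (define (out-node k) (hash-ref (obatch-nodes out) k))
  (define mapping (make-vector n #f))
  (define size (make-hash))
  (for ([i (in-range n)])
    (define u (obatch-len out))
    ;; Translate backreferences from the input space to the output space.
    (define node
      (match (vector-ref batch i)
        [(list op args ...) (cons op (map (λ (a) (vector-ref mapping a)) args))]
        [leaf leaf]))
    (define j
      (match node
        [(list 'log e)
         #:when (match (out-node e) [(list 'exp _) #t] [_ #f])
         (second (out-node e))]                ; no new nodes
        [(list 'neg e)
         (add! out (list '* (add! out -1) e))] ; up to two new nodes
        [_ (add! out node)]))
    (vector-set! mapping i j)
    ;; Fill the auxiliary column for every index from u to the new length v,
    ;; not just for j: the added nodes may land anywhere in that range.
    (for ([k (in-range u (obatch-len out))])
      (hash-set! size k
                 (match (out-node k)
                   [(list _ args ...) (apply + 1 (map (λ (a) (hash-ref size a)) args))]
                   [_ 1]))))
  (values out mapping size))

(define-values (ex-out ex-mapping ex-size)
  (rewrite (vector 'x '(exp 0) '(log 1) '(neg 2) '(+ 3 2))))
ex-mapping ; => '#(0 1 0 3 4)
```

In the example, rewriting the neg node adds two output nodes, the literal -1 and the product, and the literal is not the returned index j, which is why the fill loop covers u through v. The (exp x) node survives in the output batch as a zombie.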

And by the way, the pretty complex schedule here is not a one-off. The overall Taylor series pipeline in Herbie has something similar, where it is regularly generating a node, adding it to a batch, simplifying it, checking if it simplifies to 0, and using that to determine which new nodes to generate. So in this case we'd have three batches (input, generated, and simplified) along with a whole bunch of mappings, data columns, and backreferences, all of which have to be updated in lockstep.

Conclusion

The previous post covered a simple type system that guarantees index safety for computations on batches. This post covered a much more complex type system that guarantees (I think?) dependency safety as well. This type system does provide a lot of leverage to avoid subtle bugs, but it is also pretty dang complicated, and programming with it feels like Haskell in that you're putting in a lot of sweat to make the type checker happy.

Next I'm going to be thinking about making this all easier to use using higher-level combinators. My dream is to separate the execution order, which as you can see is quite complex, from the actual meaning of the code, in a Halide style.

Footnotes:

[1] Maybe? Don't love the name.

[2] Read the curly braces around col-set!'s return type as meaning that it outputs that value, but as an effect. This notation is non-standard but will become necessary shortly.

[3] You can actually do it with index space α too, which is similarly complex but with some differences.

[4] Actually a subtree but, uhh, there's enough going on, ignore this bit.