Scheduling Batch Computations
Last time, I talked about batches, my term for an arena-allocated, flattened form of an AST, which I think is going to be a key data structure in Herbie. That post covered the basics and sketched out a type theory; in this post I want to dive deeper into computing with batches.
Batch Refresher & Index Safety
Here's just a quick review from last time. A batch stores a recursive data structure like an AST in a flattened form. The batch itself is a vector of "nodes", where a node is an AST node but with subexpressions replaced by integers that represent back-references into the same batch. In a low-level language like Rust, batches can lead to big speedups due to cache effects and reduced memory usage, but even in a high-level language like Racket, which Herbie is written in, batches can deduplicate common subexpressions, which are very common, especially in Herbie.
The last post sketched out a basic type system for batches. Every batch gets a unique type-level variable called an "index space", which we write α, β, and so on and which we also type-pun as an integer. Then a batch is:
batch<α, F> = vector<F<α>>
Here F is a "node" type, parameterized by the type used for subexpressions. An actual expression is F<F<F<...>>>, which is typically written μ<F>.
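To make the flattened form concrete, here's a minimal sketch in Python (the post's code is Racket; `nodes` and `unflatten` are made-up names): a batch is a vector of nodes whose children are integer back-references, and unflattening rebuilds the nested μ<F> tree form.

```python
# A node stores an operator plus child *indices* into the same batch,
# rather than child subtrees. Flattening (1 + x) * (1 + x):

nodes = [
    ("1", ()),      # index 0
    ("x", ()),      # index 1
    ("+", (0, 1)),  # index 2: 1 + x
    ("*", (2, 2)),  # index 3: both children are back-references to node 2
]

def unflatten(i):
    """Rebuild the nested tree form F<F<F<...>>> from the batch."""
    op, children = nodes[i]
    return (op, *(unflatten(c) for c in children))

print(unflatten(3))
# → ('*', ('+', ('1',), ('x',)), ('+', ('1',), ('x',)))
```

Note how the shared subexpression 1 + x is stored once, at index 2, even though it appears twice in the tree.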
The point of the α
parameter is to provide "index safety": to keep
track of which integer index points to which batch. The same index
means different expressions in different batches, and by thinking of
indices not as plain integers but as these "index space" newtypes, we
can enforce a type discipline that prevents mistakes like using an
index in the wrong batch. The actual implementation of this type
parameter sort of depends on the language. It could be existentially
bound, or a dependent type, or something else.
Another thing we often need to do is to compute some generic value for each expression in a batch. An example might be the cost of an expression, or the output from evaluating it on some input, or its series expansion, or something like that. We call those data:
data<α, T> = vector<T>
Note that again we keep track of the index space at the type level. This keeps track of what batch the values were computed on, so we don't get confused once multiple batches are in play.
Finally, a lot of functions take in one batch and output another—think a "rewrite" operation, for example—and in the process they need to keep track of which input expressions turn into which output expressions. We call this a "mapping":1 [1 Maybe? Don't love the name.]
mapping<α, β> = vector<β>
Note that the mapping goes from one index space to another; the type discipline basically requires you to use mappings appropriately and thus guarantees index safety.
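As a sketch of how a mapping gets used, here's a toy rewrite pass in Python (hypothetical names; a real Herbie pass would be more involved) that copies an input batch into an output batch while recording which output index each input index became:

```python
# Input batch in index space α; output batch in index space β.
# The "rewrite" here just renames "x" to "y"; mapping[i] is the
# β-index that α-index i turned into.

input_batch = [("x", ()), ("1", ()), ("+", (0, 1))]
output_batch = []
mapping = []  # mapping<α, β>: a vector of β-indices

for op, children in input_batch:
    # Back-references are α-indices; translate them through the mapping.
    new_children = tuple(mapping[c] for c in children)
    new_op = "y" if op == "x" else op
    output_batch.append((new_op, new_children))
    mapping.append(len(output_batch) - 1)

print(mapping)       # → [0, 1, 2]
print(output_batch)  # → [('y', ()), ('1', ()), ('+', (0, 1))]
```

The key discipline: a back-reference from the input batch is never used directly in the output batch; it always goes through the mapping first.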
In total, I call these three types—batches, data, and
mappings—different "columns", written col<α, X>
, whose index space
is α
. You can think of this as subtyping or interfaces or whatever.
Besides enforcing index safety, the type discipline also provides this guarantee: if two columns have the same index space, then their vectors have the same underlying length, which means you can traverse them in parallel (because an index into one is a valid index into the other).
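Here's a small Python sketch of that guarantee in action (toy names): a batch and a data column with the same index space have the same length, so a plain index into one is valid in the other:

```python
batch = [("1", ()), ("x", ()), ("+", (0, 1))]  # batch<α, F>
size = [0] * len(batch)                        # data<α, int>: same length

# Same index space means same length: index i is valid in both,
# and a node's back-references index into `size` safely.
for i, (op, children) in enumerate(batch):
    size[i] = 1 + sum(size[c] for c in children)

print(size)  # → [1, 1, 3]
```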
I've taken to calling a group of columns with the same index space a "table"; in this analogy a batch is like a "database" of expressions, and the type discipline provides a "schema". Note that unlike a SQL database, the "table" is ephemeral—it is a human-level construct to make sense of iteration, not a data structure in memory—but columns are actual vectors in memory.
Computing on columns
A computation on columns typically takes the form of iterating over
all rows in a table and using them to fill in another column. The
interesting case is a recursive traversal on an expression. For
example, consider Herbie's cost function on expressions: the cost of
an expression is the cost of the node plus the cost of all
subexpressions, except that if
statements are instead a max
of the
then and else branch. The actual cost numbers are customizable and
live in a Herbie data structure called a platform
. A sketch looks like
this:
(define (cost platform expr)
  (let loop ([expr expr])
    (match expr
      [(? literal?) (platform-literal-cost platform)]
      [(? symbol?) (platform-literal-cost platform)]
      [(list 'if c t f)
       (+ (platform-if-cost platform)
          (loop c)
          (max (loop t) (loop f)))]
      [(approx s e) (loop e)]
      [(list op args ...)
       (apply + (platform-op-cost platform op) (map loop args))])))
If you're not familiar with Racket, the let loop
construct lets you
define an inline recursive function and call it on some arguments.
To run this on a batch
, we first construct an output data
column:
(define (batch-cost platform batch) ; batch : batch<α, Impl>
  (define out (make-data batch))    ; out : data<α, float>
  ...)
Here the make-data
command takes in a batch
which serves as its index
space and also defines the length of the column. I've added type
annotations (which in Racket don't do anything except keep us sane);
the Impl
type is the relevant one of Herbie's two expression types.
Now we need to loop over the batch, computing the cost of each expression:
(define (batch-cost platform batch) ; batch : batch<α, Impl>
  ...
  (for ([i (in-naturals)]           ; i : α
        [node (in-col batch)])      ; node : Impl<α>
    (define my-cost
      (match node
        [(? literal?) (platform-literal-cost platform)]
        [(? symbol?) (platform-literal-cost platform)]
        [(list 'if c t f)           ; c, t, f : α
         (+ (platform-if-cost platform)
            (col-ref out c)
            (max (col-ref out t) (col-ref out f)))]
        [(approx s e)               ; e : α
         (col-ref out e)]
        [(list op args ...)         ; args : list<α>
         (apply + (platform-op-cost platform op)
                (map (curry col-ref out) args))]))
    (col-set! out i my-cost))
  out)
This is almost the same thing as above, except with every call to loop replaced by col-ref out. Here, col-ref and col-set! are just vector-ref and vector-set! but type-checking index spaces:2 [2 Imagine the curly braces around the return value of col-set! as meaning that it outputs that value, but as an effect. This notation is non-standard but will become necessary shortly.]
col-ref  : col<α, T> -> α -> T
col-set! : col<α, T> -> α -> T -> { col<α, T> }
Following these types enforces index safety.
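In a dynamically checked setting, the same discipline can be approximated at runtime. Here's a Python sketch (my own hypothetical `Col`, `col_ref`, and `col_set`, not Herbie's API) that tags every index with its index space and checks the tag on access:

```python
class Col:
    """A column tagged with its index space (a string stands in for α)."""
    def __init__(self, space, length):
        self.space = space
        self.cells = [None] * length

def col_ref(col, idx):
    space, i = idx  # an index is a (space, int) pair, not a bare int
    assert space == col.space, "index used with the wrong index space"
    return col.cells[i]

def col_set(col, idx, value):
    space, i = idx
    assert space == col.space, "index used with the wrong index space"
    col.cells[i] = value

out = Col("alpha", 3)
col_set(out, ("alpha", 0), 1.5)
print(col_ref(out, ("alpha", 0)))  # → 1.5
# col_ref(out, ("beta", 0)) fails the assertion: wrong index space.
```

A static type system erases the tag entirely; the runtime check is just a way to see what the discipline buys you.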
Dependency safety
But some errors are still possible. For example, if I forget the col-set! call, I'll cause two issues. First, I won't save the costs anywhere. But second, I'll also read garbage data for arguments. I call this dependency safety: it's unsafe to read a value from a batch before that value is properly computed.
Dependency safety is not a problem for immutable columns, because presumably all their entries are already computed. But computing on batches necessarily requires mutating columns, and mutation introduces data dependencies that we have to track.
What matters for safety is ordering. This cost procedure is a bottom-up recursion (the costs of leaves are computed first) which is why we iterated through the nodes in order. And we know that nodes only contain backreferences. So we actually have something like these pre/post-conditions in the loop:
(define (batch-cost batch)        ; batch : batch<α, Impl>
  (define out (make-data batch))  ; out : data<α.none, float>
  (for ([i (in-naturals)]         ; i : α
        [node (in-col batch)])    ; node : Impl<α.lt(i)>
    ; out : data<α.lt(i), float>
    (define my-cost ...)
    (col-set! out i my-cost)
    ; out : data<α.lte(i), float>
    )
  ; out : data<α, float>
  )
In other words, no entries of out are valid before we enter the loop. Then, at the start of an iteration, all entries before i are valid. Crucially, the node also only contains indices before i, so those are valid indices into out. By calling col-set! we set the i-th entry, making it valid, which ensures that at the end of the iteration, all entries before and including i are valid; that's the inductive step. Thus, at the end of the loop, all entries are valid.
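The same inductive argument can be checked dynamically. Here's a Python sketch (a hypothetical `CheckedCol`, not Herbie code) where each cell carries a valid bit, so forgetting the col-set! is caught at the first premature read:

```python
class CheckedCol:
    """A column that refuses to read cells that were never written."""
    def __init__(self, length):
        self.cells = [None] * length
        self.valid = [False] * length

    def ref(self, i):
        if not self.valid[i]:
            raise RuntimeError(f"read of uninitialized cell {i}")
        return self.cells[i]

    def set(self, i, value):
        self.cells[i] = value
        self.valid[i] = True

batch = [("1", ()), ("x", ()), ("+", (0, 1))]
out = CheckedCol(len(batch))
for i, (op, children) in enumerate(batch):
    # Back-references are all < i, so they are valid by induction.
    cost = 1 + sum(out.ref(c) for c in children)
    out.set(i, cost)  # forget this line, and a later iteration raises

print(out.cells)  # → [1, 1, 3]
```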
Note what logic we need to actually check safety here. We need indices, with their different types. We also need the type α.lt(i) and similar. Those types make the col-ref and col-set! methods safer:

col-ref  : col<α.lt(i), T> -> α.lt(i) -> T
col-set! : col<α.lt(i), T> -> α.eq(i) -> T -> { col<α.lte(i), T> }
Note that col-set!
changes the index space in the type from α.lt(i)
to
α.lte(i)
, indicating that, by setting the i
-th element, that element
is now presumed valid. This type discipline enforces the dependency
constraint (that we only read from valid entries), requires us to
write to the vector, and also prevents weirder issues like writing
twice to the same entry.
Though note that as a type system, it's pretty weird! We have type state and functions that mutate the type of their argument. To make that sane you'd want col-set! to return the updated vector in a functional style, which is fine, but then you'll want some kind of affine type system to avoid accidental copies. And, of course, you'd probably want a solver for ordering constraints (if i : α.lt(j) and j : α.lt(k), then i : α.lt(k); I think it'd be tedious to have to write conversion functions, which would act basically as a proof) and type simplification (α.lt(i).lt(j) should simplify to something? Maybe α.lt(min(i, j))?). I think you could do it statically, especially maybe in a language like Rust that provides affine types.
Just as one example of the weirdness, what type does a batch have?
Above I wrote batch<α, F>
is col<α, F<α>>
, which is true enough.
But if we want to be particular about dependencies, shouldn't it be
batch<α, F> = col<i : α, F<α.lt(i)>>
In other words, we need a new kind of dependency just to type col
terms
properly. The second argument to col
would have to be a function:
col : (α : Type) -> (α -> Type) -> Type
This would be a huge pain to work with in the real world, and would
need a hugely sophisticated type system. But, you know, possible in
theory. (I would not do this weird thing and would just have a special
batch-ref
function that did the right thing.) I think, done correctly,
you should get a kind of erasability, where firstly you should be able
to erase α.lt(i)
to just α
(and still have index safety) and then
secondly you should be able to erase α
to int
and col<α, T>
to
vector<T>
to actually execute the code.
More complex dependency safety examples
If you stopped and actually read Herbie's platform cost procedure, you'd have noticed that the sketch I gave above is too simple. To compute the cost of an Impl in Herbie, we need to know its type, which in Herbie is called a repr, so the cost procedure is a top-down traversal, computing the repr, followed by the bottom-up cost computation:
(define (cost platform repr expr)
  (let loop ([repr repr] [expr expr])
    (match expr
      ...
      [(list op args ...)
       (define itypes (impl-info op 'itype))
       (define arg-costs
         (for/list ([arg (in-list args)]
                    [itype (in-list itypes)])
           (loop itype arg)))
       (apply + (platform-op-cost platform op) arg-costs)])))
I'm showing one branch of the match
; you can fill in the rest. The
itypes
variable stores the types of each argument to op
, so we
traverse itypes
and args
together and make the appropriate recursive
call.
To implement this in batches, we'd need two traversals:
(define (batch-cost platform batch)
  (define reprs (make-data batch))
  (for ([i (in-range (- (col-length batch) 1) -1 -1)]
        [node (in-col-rev batch)])
    (define my-repr (col-ref reprs i))
    (match node
      ...
      [(list op args ...)
       (define itypes (impl-info op 'itype))
       (for ([arg (in-list args)]
             [itype (in-list itypes)])
         (col-set! reprs arg itype))]))
  (define out (make-data batch))
  (for ([i (in-naturals)]
        [node (in-col batch)])
    (define my-cost ...)
    (col-set! out i my-cost))
  out)
In this top-down traversal, reprs stores the type of each node, which is its argument to the loop procedure. So in this loop reprs has type col<α.gte(i), repr> (we've already "called" the current node) and col-set! has an alternative type:

col-set! : col<α.gte(i), T> -> α.lt(i) -> T -> { col<α.gte(i), T> }
But hold on! In this case the type of the argument doesn't change. And
that's correct: there's no guarantee that the node right before the
current one is pointed to by the current one. It might be pointed to
by some unrelated expression. Or, it might not be pointed to by any
node at all—it could be a "zombie node" as discussed in the last
post. So we need to initialize every node, at least, with some kind of
default repr
, so we need an init
value. Moreover, it's possible for
two different iterations to col-set!
the same index, for example if a
subexpression appears twice in the tree. Then it's possible for it to
be assigned two different reprs, so we'll need some kind of join
function.
But then what type should the column have? Since every value is
initialized, perhaps just col<α, T>
, indicating that every index is
valid? But then it seems like we'll lose all dependency safety, and
would be allowed to read any value at all. But it can't just be
col<α.gte(i)>
, because then we lose the inductive property.
I think the best bet might be adding a new type of column that stores
its init
value and join
method, and then adding a col-default!
operation on these columns, with type:
col-default! : col<α.gt(i), T> -> α.eq(i) -> { col<α.gte(i), T> }
The behavior of the function is to join
the current cell with the init
value (assuming the init
is a unit of join
), and the user calls this
at the start of every iteration to make sure the cell is valid.
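Here's a Python sketch of such a column (the init/join interface and `join_repr` are my guesses at what this would look like, not Herbie's actual code):

```python
class JoinCol:
    """A column carrying an init value and a join for repeated writes."""
    def __init__(self, length, init, join):
        self.init = init
        self.join = join
        self.cells = [init] * length

    def default(self, i):
        # Joining with init is a no-op when init is a unit of join;
        # the call just marks the cell as deliberately valid.
        self.cells[i] = self.join(self.cells[i], self.init)

    def set(self, i, value):
        self.cells[i] = self.join(self.cells[i], value)

# Toy join for reprs: None is the "unknown" unit; otherwise keep the
# first assignment (a real join might check that the two reprs agree).
def join_repr(a, b):
    if a is None:
        return b
    if b is None:
        return a
    return a

reprs = JoinCol(3, None, join_repr)
reprs.default(0)     # cell 0 stays at the init value, but is now valid
reprs.set(1, "f64")
reprs.set(1, "f64")  # a repeated write joins rather than overwrites
print(reprs.cells)   # → [None, 'f64', None]
```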
Mutable and immutable
What all this is pointing at is that mutable columns are more complex than immutable ones.
An immutable column has a clear and simple type like col<α, T>
, with
no constraints on α
. A mutable column, on the other hand, has some
weird constraints on α
, and has to do some kind of loop tracking how
these constraints are updated. At the end of the loop, every cell is
filled in, and the column becomes immutable.
So far so good—for data columns (and mappings). These have a fixed size. But batches are more complex. Batches are updated by adding expressions to them, which are then deduplicated. This means that a mutable batch, one that we're not done adding things to yet, has unknown total size.
Our constraint type system can represent this, actually. A mutable batch is a column, with only indices up to its current length valid:
mutable-batch<α, F> ◁ ∃ u : α, col<i : α.lt(u), F<α.lt(i)>>
I used an existentially-bound u
here but you could be explicit that
this is the length of the batch if your type system allowed it. (Our
hypothetical type system is complex enough!) Then adding to a batch
would do:
batch-add! : (∃ u : α, col<i : α.lt(u), F<α.lt(i)>>)
          -> F<α.lt(u)>
          -> { ∃ v : α.gte(u), col<i : α.lt(v), F<α.lt(i)>> } * α.lt(v)
Note that the mutable batch initially has length u
, and the node we
add has to have all its backreferences less than u
to make sure they
are valid. Then the output has a new length v
, which is at least as
big as u
(not strictly bigger, because we could have added a duplicate
node and not changed the length) and the mutable batch now has this
new length. The index of the added node is less than the new length.
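A Python sketch of a mutable batch with this behavior (hypothetical names; dedup is exactly what makes the new length "not strictly bigger" sometimes):

```python
class MutableBatch:
    """A growable, deduplicating batch; its length only ever increases."""
    def __init__(self):
        self.nodes = []
        self.memo = {}  # dedup table: node -> index

    def add(self, op, *children):
        # Back-references must point below the current length u.
        assert all(c < len(self.nodes) for c in children)
        key = (op, children)
        if key not in self.memo:  # fresh node: new length v = u + 1
            self.memo[key] = len(self.nodes)
            self.nodes.append(key)
        return self.memo[key]     # duplicate: v = u, return the old index

b = MutableBatch()
x = b.add("x")
s1 = b.add("+", x, x)
s2 = b.add("+", x, x)        # deduplicated: length unchanged, same index
print(len(b.nodes), s1, s2)  # → 2 1 1
```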
This might seem to be getting way too abstract to be useful, but sadly not. For example, Herbie has a pass called "reduction" which is called as part of Taylor expansion. Here's how it works:
- It checks for patterns like log(exp(x)) and simplifies them to x. This is done via a bottom-up traversal.
- But for arithmetic, it calls a gather-additive-terms function, which turns sums into canonical forms. (It does this to find cancellation opportunities.) There's a similar gather-multiplicative-terms function. Both are bottom-up traversals.
- (reduce e) first calls reduce on all subexpressions of e, then passes the resulting expression to gather-additive-terms, then uses that to compute the reduction of e.
If we were to turn this into a batch-based procedure, we would have:
- The input batch, type batch<α, F>
- The output batch, type mutable-batch<β, F>
- A mapping between them, type mapping<α, β>
- A column of gather-additive-terms outputs (and same for gather-multiplicative-terms). Since these are called after recursive reduction, this column has index space β.3 [3 You can actually do it with index space α too, which is similarly complex but with some differences.] Canonical forms of additive nodes contain backreferences, so this column has type data<i : β, ATerms<β.lt(i)>>.
The type checking is complex:
- Suppose we're currently on index i in the input batch, and the output batch has length u.
- So when reduce checks for patterns, it's looking at an F<α.lt(i)>.
- The mapping maps the backreferences α.lt(i) to, uhh, β.lt(u), I guess? Hold on tight, it gets worse.
- This gives us an F<β.lt(u)>, which we pass to gather-additive-terms. Note that gather-additive-terms is, presumably, valid up to u, so it can look up the recursive calls it needs.
- That returns a node,4 [4 Actually a subtree but, uhh, there's enough going on, ignore this bit.] which we add to the output batch, which now has length v, getting a new index j : β.lt(v).
- Now we need to update the mapping. Before the update, it has type mapping<α.lt(i), β.lt(u)>. After, it has type mapping<α.lte(i), β.lt(v)>. Note that we need to update both parameters. The second parameter can be updated by simple subtyping, and then the first parameter is updated by col-set!.
- Now the output batch and the mapping have been updated, but we don't yet have updates for gather-additive-terms for all the new nodes in the output batch. So we need to fill in that column from u to v.
Hope you followed that. It is deep into the thick of type-based
programming! But you can also see the work the types are doing. For
example, that last step, of updating gather-additive-terms
, is not
obvious. It seems like you probably only need to update it for the one
new node j
, right? Not so, because a later addition might for some
reason deduplicate into a node between u
and j
.
And by the way, the pretty complex schedule here is not a one-off. The
overall Taylor series pipeline in Herbie has something similar, where
it is regularly generating a node, adding it to a batch, simplifying
it, checking if it simplifies to 0
, and using that to determine which
new nodes to generate. So in this case we'd have three batches (input,
generated, and simplified) along with a whole bunch of mappings, data
columns, and backreferences, all of which have to be updated in
lockstep.
Conclusion
The previous post covered a simple type system that guarantees index safety for computations on batches. This post covered a much more complex type system that guarantees (I think?) dependency safety as well. This type system does provide a lot of leverage to avoid subtle bugs, but it is also pretty dang complicated, and programming with it feels like Haskell in that you're putting in a lot of sweat to make the type checker happy.
Next I'm going to be thinking about making this all easier to use via higher-level combinators. My dream is to separate the execution order, which as you can see is quite complex, from the actual meaning of the code, in a Halide style.