Stream Fuse Carefully
The UW reading group in programming languages recently read “Exploiting vector instructions with generalized stream fusion”, from ICFP 2013, by G. Mainland, R. Leshchinskiy, and S. Peyton Jones. The paper shows a form of stream fusion that can use the CPU’s special vector instructions (as well as bulk-copy operations such as memcpy). James and I noticed an interesting gotcha in using stream fusion, which we’re writing about here.
What is stream fusion?
Stream fusion is a way to avoid intermediate data structures when writing code in a functional style. Suppose we have the code
foo xs = sum $ filter even $ map floor $ xs
This code will be slow, because the map and filter invocations each allocate a new list. Instead, it would be better to combine the iteration over the intermediate lists into one giant recursive function:
fooᵣ xs = fooᵣ' xs 0
  where
    fooᵣ' (x:xs) s =
      let r = floor x
      in if even r then fooᵣ' xs (s + r) else fooᵣ' xs s
    fooᵣ' [] s = s
This code is faster and avoids allocation but is ugly to write. The functional version used many small functions together to achieve its results; our recursive version had to reimplement all of them.
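To make the comparison concrete, here is a sketch of both versions in plain Haskell. The type signatures and the ASCII name fooRec are our own choices for illustration; the post's snippets leave the types implicit.

```haskell
-- Pipeline version: map and filter each build an intermediate list.
foo :: [Double] -> Integer
foo xs = sum $ filter even $ map floor xs

-- Hand-fused version: a single loop with an accumulator and no
-- intermediate lists. Odd values leave the accumulator unchanged.
fooRec :: [Double] -> Integer
fooRec xs0 = go xs0 0
  where
    go (x:xs) s =
      let r = floor x
      in if even r then go xs (s + r) else go xs s
    go [] s = s
```

Loaded into GHCi, foo [1.5, 2.2, 3.9, 4.0, 6.7] and fooRec [1.5, 2.2, 3.9, 4.0, 6.7] both evaluate to 12.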
Stream fusion achieves the same result as our recursive code by using clever versions of map, filter, and sum. The map, filter, and sum operations on lists become the mapₛ, filterₛ, and sumₛ operations on streams. A stream is a stateful function which produces either elements of the stream or a special “skip” instruction. The skip instruction allows functions like filterₛ to skip certain elements. The Step type is parameterized over the element type and also the state s used by the stream generator.
data Step α s = Yield α s | Skip s | Done
A stream, then, internally has some state type (∃ s), and contains a starting state of type s and a function s → Step α s which produces steps in the stream.
data Stream α = ∃ s · Stream {next: s → Step α s, start: s}
We can turn lists into streams and back:
stream :: [α] → Stream α
stream l = Stream { next: iter, start: l }
  where
    iter (x:xs) = Yield x xs
    iter []     = Done

unstream :: Stream α → [α]
unstream (Stream { next: iter, start: s₀ }) = go iter s₀
  where
    go iter s = case iter s of
      Yield elt s' → elt : go iter s'
      Skip s'      → go iter s'
      Done         → []
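The definitions so far use a stylized notation. A compilable rendering might look like the sketch below; it assumes GHC's ExistentialQuantification extension and uses positional constructor fields rather than the record syntax above, which is our simplification.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

-- One step of a stream: an element with the next state, a skip with the
-- next state, or the end of the stream.
data Step a s = Yield a s | Skip s | Done

-- A stream packages a step function with a start state; the state type s
-- is hidden behind the existential.
data Stream a = forall s. Stream (s -> Step a s) s

-- The list itself serves as the state.
stream :: [a] -> Stream a
stream l = Stream next l
  where
    next (x:xs) = Yield x xs
    next []     = Done

-- Run the step function to completion, dropping Skip steps.
unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []
```

With these definitions, unstream (stream xs) returns xs for any list xs.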
As you can imagine, we can write mapₛ, filterₛ, and sumₛ as functions on Stream α that do the expected thing. So we can now write
map f = unstream ∘ mapₛ f ∘ stream
filter f = unstream ∘ filterₛ f ∘ stream
sum = sumₛ ∘ stream
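The post leaves the stream versions to the imagination, so here is one way they might look. This is a sketch under our own assumptions (ASCII names mapS, filterS, sumS; positional Stream fields; ExistentialQuantification), not the code from the paper or from any library.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Step a s = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step a s) s

stream :: [a] -> Stream a
stream l = Stream next l
  where
    next (x:xs) = Yield x xs
    next []     = Done

-- Apply f to every yielded element; skips pass through untouched.
mapS :: (a -> b) -> Stream a -> Stream b
mapS f (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' -> Yield (f x) s'
      Skip s'    -> Skip s'
      Done       -> Done

-- Turn rejected elements into skips instead of looping to find the
-- next accepted element.
filterS :: (a -> Bool) -> Stream a -> Stream a
filterS p (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s'
        | p x       -> Yield x s'
        | otherwise -> Skip s'
      Skip s'    -> Skip s'
      Done       -> Done

-- Consume the stream with a strict accumulator.
sumS :: Num a => Stream a -> a
sumS (Stream next s0) = go 0 s0
  where
    go acc s = case next s of
      Yield x s' -> let acc' = acc + x in acc' `seq` go acc' s'
      Skip s'    -> go acc s'
      Done       -> acc
```

Note that none of the step functions is recursive except the final consumer, which is what lets the compiler inline them into a single loop.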
The benefit of this encoding is that our original functional code becomes
foo xs = sumₛ ∘ stream $ unstream ∘ filterₛ even ∘ stream $ unstream ∘ mapₛ floor ∘ stream $ xs
Now, in this code the only intermediate data structures are produced by the stream ∘ unstream blocks. Since these calls don’t do anything to our list elements, we can remove them, using GHC’s RULES pragma:
{-# RULES "STREAM stream/unstream fusion" forall s. stream (unstream s) = s #-}
GHC will then eliminate the stream ∘ unstream calls and then inline the mapₛ, filterₛ, and sumₛ blocks. After optimizing away the Step data structures, the same tight code as fooᵣ will be produced. And all of this happens automatically, so our code is clean even though it runs quickly.
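Putting the pieces together, a self-contained module with the rewrite rule might look like the sketch below. The phase annotations (INLINE [1]) are our assumption about how to keep stream and unstream from being inlined before the rule has a chance to fire; the result of foo is the same whether or not it does. Compiling with -O and GHC's -ddump-rule-firings flag is one way to see whether the rule actually fired.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
import Prelude hiding (map, filter, sum)

data Step a s = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step a s) s

stream :: [a] -> Stream a
stream l = Stream next l
  where
    next (x:xs) = Yield x xs
    next []     = Done
{-# INLINE [1] stream #-}

unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []
{-# INLINE [1] unstream #-}

-- The rewrite rule from the post.
{-# RULES
"stream/unstream" forall s. stream (unstream s) = s
  #-}

mapS :: (a -> b) -> Stream a -> Stream b
mapS f (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' -> Yield (f x) s'
      Skip s'    -> Skip s'
      Done       -> Done

filterS :: (a -> Bool) -> Stream a -> Stream a
filterS p (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' | p x       -> Yield x s'
                 | otherwise -> Skip s'
      Skip s'    -> Skip s'
      Done       -> Done

sumS :: Num a => Stream a -> a
sumS (Stream next s0) = go 0 s0
  where
    go acc s = case next s of
      Yield x s' -> let acc' = acc + x in acc' `seq` go acc' s'
      Skip s'    -> go acc s'
      Done       -> acc

-- List-level wrappers, inlined so that adjacent stream/unstream pairs
-- become visible to the rule.
map :: (a -> b) -> [a] -> [b]
map f = unstream . mapS f . stream
{-# INLINE map #-}

filter :: (a -> Bool) -> [a] -> [a]
filter p = unstream . filterS p . stream
{-# INLINE filter #-}

sum :: Num a => [a] -> a
sum = sumS . stream
{-# INLINE sum #-}

foo :: [Double] -> Integer
foo xs = sum $ filter even $ map floor xs
```

Rewrite rules are ignored by the GHCi interpreter, so the fused and unfused paths can be compared by loading the module in GHCi and by compiling it with -O.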
What’s the problem?
As implemented, stream fusion is a great optimization which achieves great performance. There’s no problem with it. But there is a subtle gotcha in implementing it. You see, stream ∘ unstream is not exactly equal to the identity.
The unstream function skips Skip steps, because those don’t carry elements of the stream. The stream function, since it operates on an intermediate list, cannot restore them. But when stream ∘ unstream is replaced with the identity, those skip steps hang around.
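One way to see the difference is to look at the raw steps of a stream. The helper steps below is hypothetical, something we made up for debugging; it renders each Skip as Nothing and each Yield x as Just x. The rest is the same sketch of streams as before, with ASCII names.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Step a s = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step a s) s

stream :: [a] -> Stream a
stream l = Stream next l
  where
    next (x:xs) = Yield x xs
    next []     = Done

unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []

filterS :: (a -> Bool) -> Stream a -> Stream a
filterS p (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' | p x       -> Yield x s'
                 | otherwise -> Skip s'
      Skip s'    -> Skip s'
      Done       -> Done

-- Hypothetical debugging helper: make the skips observable.
steps :: Stream a -> [Maybe a]
steps (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> Just x : go s'
      Skip s'    -> Nothing : go s'
      Done       -> []

-- The same stream, observed directly and after a round trip through a list.
direct, roundTrip :: [Maybe Int]
direct    = steps (filterS even (stream [1 .. 6]))
roundTrip = steps (stream (unstream (filterS even (stream [1 .. 6]))))
```

Here direct is [Nothing, Just 2, Nothing, Just 4, Nothing, Just 6], while roundTrip is [Just 2, Just 4, Just 6]: the round trip through a list has erased the skips.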
This isn’t a problem for map or filter, because both of those skip when their input stream skips, without changing their internal state in any way. But some care needs to be taken to make sure that every stream function has this behavior.
A stream function might depend on the presence of skips by, for example, replacing them with previous elements of the stream. Of course, this is a constructed example; you wouldn’t accidentally write this code. But it demonstrates the principle:
fillₛ (Stream { next: iter, start: s₀ }) = Stream { next: go, start: (s₀, Nothing) }
  where
    go (s, prev) = case iter s of
      Yield x s' → Yield x (s', Just x)
      Skip s'    → case prev of
        Nothing → Skip (s', prev)
        Just x  → Yield x (s', prev)
      Done       → Done
If the input stream yields elements of type α, our state has type (s, Maybe α). The second part of the state is the most recent element of the input stream; it starts as Nothing (because there isn’t yet a previous element), and is set to Just x every time the input stream produces an x. When the input stream skips an element, we output the previous element if one exists.
For example,
unstream $ fillₛ $ filterₛ even $ stream [1..6] ⇒ [2, 2, 4, 4, 6]
Filtering away the odd elements results in a stream containing
skip ; yield 2 ; skip ; yield 4 ; skip ; yield 6
Then the fill function fills in the later two skips.
However, if we go ahead and define
fill = unstream ∘ fillₛ ∘ stream
we run into trouble. Without the rewrite rule,
fill $ filter even $ [1..6] ⇒ [2, 4, 6]
because the intermediate stream ∘ unstream removes all of the skips from the stream. With the rewrite rule, we instead see
fill $ filter even $ [1..6] ⇒ [2, 2, 4, 4, 6]
So the use of the rewrite rule changes the meaning of our program. That’s a problem, because rewrite rules can be turned on or off by compiler flags or optimization level.
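The divergence can be reproduced without relying on the optimizer by writing out both pipelines by hand. In the sketch below, unfused keeps the intermediate stream . unstream pair and fused has it erased, which is what the rewrite rule would do. The names unfused and fused, and the ASCII fillS and filterS, are ours.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Step a s = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step a s) s

stream :: [a] -> Stream a
stream l = Stream next l
  where
    next (x:xs) = Yield x xs
    next []     = Done

unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []

filterS :: (a -> Bool) -> Stream a -> Stream a
filterS p (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' | p x       -> Yield x s'
                 | otherwise -> Skip s'
      Skip s'    -> Skip s'
      Done       -> Done

-- Replace each skip with the most recently yielded element, if any.
fillS :: Stream a -> Stream a
fillS (Stream next s0) = Stream next' (s0, Nothing)
  where
    next' (s, prev) = case next s of
      Yield x s' -> Yield x (s', Just x)
      Skip s'    -> case prev of
        Nothing -> Skip (s', prev)
        Just x  -> Yield x (s', prev)
      Done       -> Done

-- fill (filter even xs) with the list boundary left in place: the skips
-- are dropped by unstream before fillS ever sees them.
unfused :: [Int] -> [Int]
unfused = unstream . fillS . stream . unstream . filterS even . stream

-- The same pipeline with stream . unstream erased, as the rule would do.
fused :: [Int] -> [Int]
fused = unstream . fillS . filterS even . stream
```

Evaluating unfused [1..6] gives [2, 4, 6], while fused [1..6] gives [2, 2, 4, 4, 6].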
But would this actually show up?
Of course, fillₛ is not a function one would normally define. Yet there are reasonable functions where a typo might lead to different behavior with and without the stream ∘ unstream rewrite rule.
Consider the enumerate function, which we’d normally write as
enumerate :: [α] → [(ℤ, α)]
enumerate xs = zipWith (,) [0..] xs
enumerate pairs each element of a list with its index in the list; it’s often handy.
We can write this as a stream function:
enumerateₛ (Stream { next: iter, start: s₀ }) = Stream { next: go, start: (s₀, 0) }
  where
    go (s, n) = case iter s of
      Yield x s' → Yield (n, x) (s', n+1)
      Skip s'    → Skip (s', n+1)   -- should be (s', n)
      Done       → Done

enumerate = unstream ∘ enumerateₛ ∘ stream
There’s a bug in the Skip case, where we increment the index even though we have not seen a new element. Thanks to this bug, we have
enumerate $ filter even $ [1..7] ⇒ [(0, 2), (1, 4), (2, 6)]
without the rewrite rule, but as soon as we enable optimizations,
enumerate $ filter even $ [1..7] ⇒ [(1, 2), (3, 4), (5, 6)]
Note that the indices now track the index in the original list, not in the intermediate list that has been optimized away.
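Here is a sketch that puts the buggy and corrected versions side by side and applies the rewrite by hand, so the difference shows up regardless of optimization level. The names enumerateBuggyS, enumerateS, and the four pipelines are ours, chosen for illustration.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Step a s = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step a s) s

stream :: [a] -> Stream a
stream l = Stream next l
  where
    next (x:xs) = Yield x xs
    next []     = Done

unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []

filterS :: (a -> Bool) -> Stream a -> Stream a
filterS p (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' | p x       -> Yield x s'
                 | otherwise -> Skip s'
      Skip s'    -> Skip s'
      Done       -> Done

-- Buggy: the index advances on Skip as well as on Yield.
enumerateBuggyS :: Stream a -> Stream (Integer, a)
enumerateBuggyS (Stream next s0) = Stream next' (s0, 0)
  where
    next' (s, n) = case next s of
      Yield x s' -> Yield (n, x) (s', n + 1)
      Skip s'    -> Skip (s', n + 1)
      Done       -> Done

-- Corrected: a Skip leaves the index alone.
enumerateS :: Stream a -> Stream (Integer, a)
enumerateS (Stream next s0) = Stream next' (s0, 0)
  where
    next' (s, n) = case next s of
      Yield x s' -> Yield (n, x) (s', n + 1)
      Skip s'    -> Skip (s', n)
      Done       -> Done

-- Each version with the list boundary kept (unfused) and erased (fused).
buggyUnfused, buggyFused, fixedUnfused, fixedFused :: [Int] -> [(Integer, Int)]
buggyUnfused = unstream . enumerateBuggyS . stream . unstream . filterS even . stream
buggyFused   = unstream . enumerateBuggyS . filterS even . stream
fixedUnfused = unstream . enumerateS . stream . unstream . filterS even . stream
fixedFused   = unstream . enumerateS . filterS even . stream
```

On [1..7], buggyUnfused gives [(0,2), (1,4), (2,6)] and buggyFused gives [(1,2), (3,4), (5,6)], while fixedUnfused and fixedFused agree on [(0,2), (1,4), (2,6)].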
Of course, this pitfall has been carefully avoided by the authors of the Haskell stream fusion libraries. But if you’re writing your own, watch out: this bug is easy to make and easy to miss. I’d write a QuickCheck property checking that each of your stream functions really does respect stream (unstream s) = s.
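Such a property might be sketched as follows. To keep the block dependency-free it is written as a plain Boolean function; fromSteps and respectsRule are hypothetical names of our own. The idea is to build a stream from a script of steps (Nothing for a skip, Just x for a yield) and check that a stream function produces the same elements with and without a stream (unstream s) round trip on its input.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Step a s = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step a s) s

stream :: [a] -> Stream a
stream l = Stream next l
  where
    next (x:xs) = Yield x xs
    next []     = Done

unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []

-- Build a stream from a script: Nothing is a Skip, Just x is a Yield.
fromSteps :: [Maybe a] -> Stream a
fromSteps = Stream next
  where
    next (Just x : rest)  = Yield x rest
    next (Nothing : rest) = Skip rest
    next []               = Done

-- Does f give the same elements whether or not the skips in its input
-- have been erased by a round trip through a list?
respectsRule :: Eq b => (Stream a -> Stream b) -> [Maybe a] -> Bool
respectsRule f script =
  unstream (f s) == unstream (f (stream (unstream s)))
  where
    s = fromSteps script

-- Two functions to try it on: a well-behaved map and the skip-sensitive fill.
mapS :: (a -> b) -> Stream a -> Stream b
mapS f (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' -> Yield (f x) s'
      Skip s'    -> Skip s'
      Done       -> Done

fillS :: Stream a -> Stream a
fillS (Stream next s0) = Stream next' (s0, Nothing)
  where
    next' (s, prev) = case next s of
      Yield x s' -> Yield x (s', Just x)
      Skip s'    -> case prev of
        Nothing -> Skip (s', prev)
        Just x  -> Yield x (s', prev)
      Done       -> Done
```

For the script [Nothing, Just 1, Nothing, Just 2], respectsRule (mapS (+1)) holds and respectsRule fillS fails. With QuickCheck installed, passing something like respectsRule (mapS (+1)) :: [Maybe Int] -> Bool to quickCheck should exercise it on random scripts; this only covers finite streams.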