sumsqの最適化の例

 Stream Fusionの利点の一つは,GHCが提供する様々な最適化機能との親和性が高いことです。リストの各要素を2乗してそれらの和を求めるプログラムであるsumsqを例に考えてみましょう(参考リンク)。sumsqは以下のように定義されているとします。

sumsq :: [Int] -> Int
sumsq = sum . map (\x -> x * x)

 ここで用いたData.List.Streamモジュールのsum関数は,以下のように定義されています。

-- | /O(n)/, /fusion/. The 'sum' function computes the sum of a finite list of numbers.
sum :: Num a => [a] -> a
sum l = sum' l 0
  where
    sum' []     a = a
    sum' (x:xs) a = sum' xs (a+x)
{-# NOINLINE [1] sum #-}

~ 略 ~

{-# RULES
"sum -> fusible"  [~1] forall xs.
    sum xs = Stream.sum (stream xs)
--"sum -> unfused" [1] forall xs.
--    Stream.sum (stream xs) = sum xs
  #-}

 この定義の中にある"sum -> fusible"規則によって作られる式「Stream.sum (stream xs)」で使われている,Data.Streamモジュールのsum関数(Stream.sum)の定義は以下の通りです。

sum :: Num a => Stream a -> a
sum (Stream next s0) = loop_sum 0 s0
  where
    loop_sum !a !s = case next s of   -- note: strict in the accumulator!
      Done       -> a
      Skip    s' -> expose s' $ loop_sum a s'
      Yield x s' -> expose s' $ loop_sum (a + x) s'
{-# INLINE [0] sum #-}

 これら二つのsum関数の定義を基に,GHCがsumsqをどのように最適化するかを見ていきましょう。

 まず書き換え規則により,sumsqの定義は以下のように書き換えられます。

    sumsq = sum . map (\x -> x * x)
==> { 「.」のインライン化
      ただし「sumsq = \xs ->」ではなく,「sumsq xs =」とする }
    sumsq xs = sum (map (\x -> x * x) xs)
==> { "map -> fusible"規則の適用 }
    sumsq xs =
      sum (unstream 
          (Stream.map g (stream xs)))
==> { "sum -> fusible"規則の適用 }
    sumsq xs =
      Stream.sum
          (stream (unstream
          (Stream.map (\x -> x * x) (stream xs))))
==> { "STREAM stream/unstream fusion"規則の適用 }
    sumsq xs =
      Stream.sum
          (Stream.map (\x -> x * x) (stream xs))

 次に,「Stream.sum (Stream.map (\x -> x * x) (stream xs))」という式で使われている各関数がインライン化されます。

    sumsq xs = Stream.sum (Stream.map (\x -> x * x) (stream xs))
==> { stream関数のインライン化 }
    sumsq xs = Stream.sum (Stream.map (\x -> x * x) (Stream next0 (L xs)))
      where
        next0 (L [])     = Done
        next0 (L (x:xs)) = Yield x (L xs)
==> { Stream.map関数のインライン化 }
    sumsq xs = Stream.sum (Stream next0 (L xs))
      where
        next0 (L [])     = Done
        next0 (L (x:xs)) = Yield x (L xs)

        next1 !s = case next0 s of
            Done       -> Done
            Skip    s' -> Skip        s'
            Yield x s' -> Yield ((\x -> x * x) x) s'
==> { Stream.sum関数のインライン化 }
    sumsq xs = loop_sum 0 (L xs)
      where
        next0 (L [])     = Done
        next0 (L (x:xs)) = Yield x (L xs)

        next1 !s = case next0 s of
            Done       -> Done
            Skip    s' -> Skip        s'
            Yield x s' -> Yield ((\x -> x * x) x) s'

        loop_sum !a !s = case next1 s of
            Done       -> a
            Skip    s' -> expose s' $ loop_sum a s'
            Yield x s' -> expose s' $ loop_sum (a + x) s'

 各関数のインライン化により,next関数やloop_sum関数もsumsqの内部関数として取り込まれます。ただし,stream関数やData.Streamモジュールのmap関数では,nextという同じ名前を持つ内部関数を利用しています。このためGHCでは,二つのnext関数の定義が衝突するのを防ぐために,それぞれの関数名を変更します。同じ名前の変数や関数が衝突するのを防ぐために名前を付け替える処理を「α変換(alpha renaming,あるいはalpha conversion)」といいます。

 sumsqの定義に内部関数が加わったので,今度は内部関数の定義をインライン化します。

    sumsq xs = loop_sum 0 (L xs)
      where
        next0 (L [])     = Done
        next0 (L (x:xs)) = Yield x (L xs)

        next1 !s = case next0 s of
            Done       -> Done
            Skip    s' -> Skip        s'
            Yield x s' -> Yield ((\x -> x * x) x) s'

        loop_sum !a !s = case next1 s of
            Done       -> a
            Skip    s' -> expose s' $ loop_sum a s'
            Yield x s' -> expose s' $ loop_sum (a + x) s'
==> { next1関数のインライン化 }
    sumsq xs = loop_sum 0 (L xs)
      where
        next0 (L [])     = Done
        next0 (L (x:xs)) = Yield x (L xs)

        loop_sum !a !s =
          case (case next1 !s of
                  Done       -> Done
                  Skip    s' -> Skip        s'
                  Yield x s' -> Yield ((\x -> x * x) x) s') of
            Done       -> a
            Skip    s' -> expose s' $ loop_sum a s'
            Yield x s' -> expose s' $ loop_sum (a + x) s'

 このとき,loop_sumにはネストしたcase式が現れます。最適化の過程でこのような入れ子になったcase式が出てきた場合,GHCはより単純な形への変換を試みます。

 loop_sumでは,case式の中にcase式が登場しているので,内部のcase式をcase式の外側に展開して単純なcase式に書き換えます。このような処理を「case式のcase式に対する変換(case-of-case transformation)」と呼びます(参考リンク1参考リンク2)。

loop_sum !a !s =
   case next1 !s of
     Done       -> case Done of
            Done       -> a
            Skip    s' -> expose s' $ loop_sum a s'
            Yield x s' -> expose s' $ loop_sum (a + x) s'
     Skip    s' -> Skip        s' of
            Done       -> a
            Skip    s' -> expose s' $ loop_sum a s'
            Yield x s' -> expose s' $ loop_sum (a + x) s'
     Yield x s' -> Yield ((\x -> x * x) x) s') of
            Done       -> a
            Skip    s' -> expose s' $ loop_sum a s'
            Yield x s' -> expose s' $ loop_sum (a + x) s'

 入れ子になったcase式では,外側のcase式に存在する値によって,内側のcase式でどの分岐が行われるかが決まります。このように,入れ子になったcase式でどの分岐が行われるかが明らかなパターンでは,内側のcase式が除去された形に書き換えられます。この処理を「既知の構成子に対するcase式の変換(case-of-known-constructor transformation)」と呼びます。GHCでは「既知の構成子」の代わりに「既知の分岐(known-branch)」という用語を使うこともあります(参考リンク1参考リンク2)。

loop_sum !a !s =
   case next1 !s of
     Done       -> a
     Skip    s' -> expose s' $ loop_sum a s'
     Yield x s' -> expose s' $ loop_sum (a + ((\x -> x * x) x)) s'

 なお「既知の構成子に対するcase式の変換」では,必ずしも内側のcase式が具体的な値に対応する必要はありません。内側のcase式が変数に対するものであっても,以下に示すように内側のcase式でどの分岐が行われるかが明らかであれば,内側のcase式が除去されます。

    case next1 !s of
      Done       -> case x of
             Done       -> a
             Skip    s' -> expose s' $ loop_sum a s'
             Yield x s' -> expose s' $ loop_sum (a + x) s'
==> { 既知の構成子に対するcase式の変換 }
    case next1 !s of
      Done       -> a

 なお,「case式のcase式に対する変換」と「既知の構成子に対するcase式の変換」をまとめて「case式のcase式に対する変換」と呼ぶこともあります(参考リンク)。

 広義の「case式のcase式に対する変換」により,入れ子になったcase式は単純なcase式に変換されます。インライン化と広義の「case式のcase式に対する変換」の適用を繰り返すことで,sumsqはより単純な定義に書き換えられます。

    sumsq xs = loop_sum 0 (L xs)
      where
        next0 (L [])     = Done
        next0 (L (x:xs)) = Yield x (L xs)

        loop_sum !a !s =
          case (case next0 !s of
                  Done       -> Done
                  Skip    s' -> Skip        s'
                  Yield x s' -> Yield ((\x -> x * x) x) s') of
            Done       -> a
            Skip    s' -> expose s' $ loop_sum a s'
            Yield x s' -> expose s' $ loop_sum (a + x) s'
==> { case式のcase式に対する変換 }
    sumsq xs = loop_sum 0 (L xs)
      where
        next0 (L [])     = Done
        next0 (L (x:xs)) = Yield x (L xs)

        loop_sum !a !s =
          case next0 !s of
            Done       -> a
            Skip    s' -> expose s' $ loop_sum a s'
            Yield x s' -> expose s' $ loop_sum (a + ((\x -> x * x) x)) s'
==> { next0関数のインライン化 }
    sumsq xs = loop_sum 0 (L xs)
      where
        loop_sum !a !s =
          case (case !s of
                  (L [])     -> Done
                  (L (x:xs)) -> Yield x (L xs)) of
            Done       -> a
            Skip    s' -> expose s' $ loop_sum a s'
            Yield x s' -> expose s' $ loop_sum (a + ((\x -> x * x) x)) s'
==> { case式のcase式に対する変換 }
    sumsq xs = loop_sum 0 (L xs)
      where
        loop_sum !a !s =
          case !s of
            (L [])     -> a
            (L (x:xs)) -> expose (L xs) $ loop_sum (a + ((\x -> x * x) x)) (L xs)

 このように,Stream Fusionを使ったコードは,GHCの最適化機能によって効率が高いコードに書き換えられます。ただし,Stream Fusionを使ったコードの最適化がうまくいくのは「Stream型を引数として取り,Stream以外の型を返す関数」が途中で再帰関数を呼んでいないことが条件です。前述したように,再帰関数はインライン化されないからです。