メモ化 @ Haskell

MemoTrieのソースを読んで,だいたい理解した.

どうやら,内部では Tree を作って,その中に関数の評価値を保持するみたい.

まず,全ての入力に対応する出力を Tree の中に保存する.しかし,実際に計算はしない.

遅延評価を利用して,必要になったときだけ評価する.だから, Tree が必要以上に大きくなることはない.

値の読みだしは Tree を探索して,値を見つけたらそれをそのまま返す,未評価ならば,評価.

おそらく, Map を利用したメモ化とほとんど同じだが,MemoTrieのほうが 高階関数 っぽい気がする.

理解を確かめるために,簡単なMemoを書いてみた.

とりあえず,入力が整数である関数のメモ化を考える.

data MemoTree a = MT a (MemoTree a) (MemoTree a)

整数を[Bool]として捉えるので,二分木を作る.

tree :: (Bits a) => (a -> b) -> MemoTree b
tree f = tpart (f.unbits) []
where tpart :: ([Bool] -> a) -> [Bool] -> MemoTree a
tpart f xs = MT (f xs) (tpart (f.(False:)) xs) (tpart (f.(True:)) xs)

unbits は[Bool]を受けとって,整数を返す関数.

treeでfの関数値を保持する木を構成する.木の深さは無限だが,実際には必要な箇所だけ評価されるので,

心配はいらない.

こんなものが作られる.

tree f = MT (f []) (MT f [False] (MT ...) (MT ...))
(MT f [True]  (MT ...) (MT ...))

木から関数値を取り出す.

untree :: Bits a => MemoTree b -> a -> b
untree mt = utpart mt.bits
where utpart :: MemoTree a -> [Bool] -> a
utpart (MT n _ _) []         = n
utpart (MT _ f _) (False:xs) = utpart f xs
utpart (MT _ _ t) (True: xs) = utpart t xs

根から,パターンマッチングでさがす.普通の探索.

bitsは 整数を[Bool]に変換する関数.

関数のメモ化.木をつくって,そこから値を取り出す.

memo :: Bits a => (a -> b) -> (a -> b)
memo = untree.tree

こんな風につかう.

fib :: Integer -> Integer
fib 1 = 1
fib 2 = 1
fib n = mFib (n - 1) + mFib (n - 2)
mFib = memo fib

実行してみる.

*MemoTree> fib 100
354224848179261915075
it :: Integer
(0.01 secs, 527336 bytes)

と,しっかり,メモ化されているみたい.

全ソースコード

module MemoTree (memo) where
import Data.Bits (Bits, testBit, shiftL, shiftR, (.|.))
data MemoTree a = MT a (MemoTree a) (MemoTree a)
memo :: Bits a => (a -> b) -> (a -> b)
memo = untree.tree
-- make MemoTree
tree :: (Bits a) => (a -> b) -> MemoTree b
tree f = tpart (f.unbits) []
where tpart :: ([Bool] -> a) -> [Bool] -> MemoTree a
tpart f xs = MT (f xs) (tpart (f.(False:)) xs) (tpart (f.(True:)) xs)
-- read from MemoTree
untree :: Bits a => MemoTree b -> a -> b
untree mt = utpart mt.bits
where utpart :: MemoTree a -> [Bool] -> a
utpart (MT n _ _) []         = n
utpart (MT _ f _) (False:xs) = utpart f xs
utpart (MT _ _ t) (True: xs) = utpart t xs
bits :: Bits t => t -> [Bool]
bits  = []
bits x = testBit x  : bits (shiftR x 1)
unbits :: Bits t => [Bool] -> t
unbits [] = 
unbits (x:xs) = unbit x .|. shiftL (unbits xs) 1
where unbit False = 
unbit True  = 1
-- example of memoization
fib :: Integer -> Integer
fib 1 = 1
fib 2 = 1
fib n = mFib (n - 1) + mFib (n - 2)
mFib = memo fib

MemoTrieはもっと高度なことができるし,コードも綺麗だ.