メモ化 @ Haskell
コンテンツ
MemoTrieのソースを読んで,だいたい理解した.
どうやら,内部では Tree を作って,その中に関数の評価値を保持するみたい.
まず,全ての入力に対応する出力を Tree の中に保存する.しかし,実際に計算はしない.
遅延評価を利用して,必要になったときだけ評価する.だから, Tree が必要以上に大きくなることはない.
値の読みだしは Tree を探索して,値を見つけたらそれをそのまま返す,未評価ならば,評価.
おそらく, Map を利用したメモ化とほとんど同じだが,MemoTrieのほうが 高階関数 っぽい気がする.
理解を確かめるために,簡単なMemoを書いてみた.
とりあえず,入力が整数である関数のメモ化を考える.
data MemoTree a = MT a (MemoTree a) (MemoTree a)
整数を[Bool]として捉えるので,二分木を作る.
tree :: (Bits a) => (a -> b) -> MemoTree b tree f = tpart (f.unbits) [] where tpart :: ([Bool] -> a) -> [Bool] -> MemoTree a tpart f xs = MT (f xs) (tpart (f.(False:)) xs) (tpart (f.(True:)) xs)
unbits は[Bool]を受けとって,整数を返す関数.
treeでfの関数値を保持する木を構成する.木の深さは無限だが,実際には必要な箇所だけ評価されるので,
心配はいらない.
こんなものが作られる.
tree f = MT (f []) (MT f [False] (MT ...) (MT ...)) (MT f [True] (MT ...) (MT ...))
木から関数値を取り出す.
untree :: Bits a => MemoTree b -> a -> b untree mt = utpart mt.bits where utpart :: MemoTree a -> [Bool] -> a utpart (MT n _ _) [] = n utpart (MT _ f _) (False:xs) = utpart f xs utpart (MT _ _ t) (True: xs) = utpart t xs
根から,パターンマッチングでさがす.普通の探索.
bitsは 整数を[Bool]に変換する関数.
関数のメモ化.木をつくって,そこから値を取り出す.
memo :: Bits a => (a -> b) -> (a -> b) memo = untree.tree
こんな風につかう.
fib :: Integer -> Integer fib 1 = 1 fib 2 = 1 fib n = mFib (n - 1) + mFib (n - 2) mFib = memo fib
実行してみる.
*MemoTree> fib 100 354224848179261915075 it :: Integer (0.01 secs, 527336 bytes)
と,しっかり,メモ化されているみたい.
全ソースコード
module MemoTree (memo) where import Data.Bits (Bits, testBit, shiftL, shiftR, (.|.)) data MemoTree a = MT a (MemoTree a) (MemoTree a) memo :: Bits a => (a -> b) -> (a -> b) memo = untree.tree -- make MemoTree tree :: (Bits a) => (a -> b) -> MemoTree b tree f = tpart (f.unbits) [] where tpart :: ([Bool] -> a) -> [Bool] -> MemoTree a tpart f xs = MT (f xs) (tpart (f.(False:)) xs) (tpart (f.(True:)) xs) -- read from MemoTree untree :: Bits a => MemoTree b -> a -> b untree mt = utpart mt.bits where utpart :: MemoTree a -> [Bool] -> a utpart (MT n _ _) [] = n utpart (MT _ f _) (False:xs) = utpart f xs utpart (MT _ _ t) (True: xs) = utpart t xs bits :: Bits t => t -> [Bool] bits = [] bits x = testBit x : bits (shiftR x 1) unbits :: Bits t => [Bool] -> t unbits [] = unbits (x:xs) = unbit x .|. shiftL (unbits xs) 1 where unbit False = unbit True = 1 -- example of memoization fib :: Integer -> Integer fib 1 = 1 fib 2 = 1 fib n = mFib (n - 1) + mFib (n - 2) mFib = memo fib
MemoTrieはもっと高度なことができるし,コードも綺麗だ.
作成者 Toru Mano
最終更新時刻 2023-01-01 (c70d5a1)