跳至內容

Haskell/延續過渡風格(CPS)

維基教科書,自由的教學讀本

延續過渡風格 (簡稱CPS) 是一種函數不直接返回值的代碼風格;在這種風格中,函數將結果傳入一個 延續 (continuation,指「之後的內容」),後者決定了之後的邏輯。本章將探討CPS能夠如何應用在Haskell中,以及如何使用Monad表達CPS。

延續是什麼?

[編輯]

回憶一下之前我們介紹 ($) 函數的時候:

> map ($ 2) [(2*), (4*), (8*)]
[4,8,16]

這段代碼並沒有什麼值得注意之處,雖然顯得有些古怪,因為我們並沒有使用通常的寫法 map (*2) [2, 4, 8]($) 函數使得代碼看起來倒置了,好像我們實際上是在把參數傳給函數,而不是將函數作用於參數上一樣。而這種看起來意義不明的倒置卻恰恰是CPS的核心! 從CPS的角度看來,($ 2) 是一個暫停中的計算: 也就是有著 (a -> r) -> r 類型的一個函數,後者接受另一個函數作為參數以產生最終結果。而這個具有 a -> r 類型的參數就是所謂的延續 (continuation);它指定了從當前函數到最終結果之間的邏輯。在上例中,列表中的函數被 map 作為延續,最終產生了三個不同的結果。值得注意的是,暫停中的計算和普通的值是可以相互轉化的。函數 flip ($) 將一個值轉化為暫停的計算 [1],而將 id 作為延續傳入暫停的計算即可得回原值。

這有什麼用呢?

[編輯]

CPS除了可以向新人炫技之外,還有更大的用途。它使得顯式的操作和改變程序的控制流成為可能。比如說,像命令式語言一樣,在一個函數全部執行完成之前返回一個值並跳出。異常和錯誤也能用CPS處理: 傳入兩個分別用於處理成功和失敗狀態的回調函數,並根據情況調用其中之一。我們還可以"暫停"一個計算,並在合適的時候使它繼續;或者我們可以實現簡單的並行計算(事實上,Hugs,一個Haskell解釋器,使用CPS來實現並行)。

在Haskell中,我們可以以類似的方式使用CPS來在Monad中實現一些控制流。通常,我們也可以用別的方法來實現這些控制結構,特別是當我們應用了惰性計算的時候。在一些場景中(例如,當一個結構複雜的返回值最終將被不被調用者使用時),通過消除一些對類型構造函數的模式匹配,CPS能夠在一定程度上改善代碼的性能;但是,這種工作或許能夠由一個足夠智能的編譯器代勞。[2]

傳遞(過渡)延續

[編輯]

我們可以修改我們的函數,使他們返回一個延續而不是一個普通的值。如下兩例。

pythagoras

[編輯]
Example
Example

例子: 一些普通而簡單的函數

-- 定义加法和平方函数:

add :: Int -> Int -> Int
add x y = x + y

square :: Int -> Int
square x = x * x

pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)


變幻成CPS後,pythagoras返回一個暫停中的計算:

Example
Example

例子: 一些簡單的CPS函數

-- 定义使用了CPS的加法和平方函数,
-- (add_cps 和 square_cps 实际上并不是严格的CPS函数
-- 他们只是类型正确罢了)

add_cps :: Int -> Int -> ((Int -> r) -> r)
add_cps x y = \k -> k (add x y)

square_cps :: Int -> ((Int -> r) -> r)
square_cps x = \k -> k (square x)

pythagoras_cps :: Int -> Int -> ((Int -> r) -> r)
pythagoras_cps x y = \k ->
 square_cps x $ \x_squared ->
 square_cps y $ \y_squared ->
 add_cps x_squared y_squared $ k


我們來看看 pythagoras_cps 是如何運作的:

  1. x 平方並將結果傳入延續 (\x_squared -> ...)
  2. y 平方並將結果傳入延續 (\y_squared -> ...)
  3. x_squaredy_squared 平方並將結果傳入頂層的延續中

我們可以在GHCi中實驗這段代碼,使用print函數作為延續:

*Main> pythagoras_cps 3 4 print
25

如果我們忽略 pythagoras_cps 類型中 (Int -> r) -> r 周圍的括號,並將其與 pythagoras 的類型對比,我們可以發現,延續實際上只是一個被柯里化的額外參數,這也是為什麼說我們往函數裡傳遞(過渡)了一個延續。

thrice

[編輯]
Example
Example

例子: 一個簡單的不使用CPS的高階函數

thrice :: (a -> a) -> a -> a
thrice f x = f (f (f x))


*Main> thrice tail "foobar"
"bar"

一個像thrice這樣的高階函數的CPS形式,與它的原本形式不同,接受的參數也是CPS形式的函數。因此,f :: a -> a 將變幻成 f_cps :: a -> ((a -> r) -> r);在這個例子中,則是 thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r)。我們可以用函數的類型作為實現的指引 - 我們將f 換成相對應的CPS形式的函數,並將這些延續一路傳遞下去。

Example
Example

例子: 一個簡單的使用CPS的高階函數

thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r)
thrice_cps f_cps x = \k ->
 f_cps x $ \fx ->
 f_cps fx $ \ffx ->
 f_cps ffx $ k


Cont monad

[編輯]

我們需要一種複合CPS函數的方法,最好能夠避免像剛剛那樣使用多層嵌套的lambda表達式。我們可以從一個將一個CPS函數應用到一個暫停的計算(即CPS版本的值)上的類似於 ($) 的函數開始。我們試試看能不能寫出他的類型:

chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)

(請讀者試一試自己實現這個函數。提示: 這個函數返回一個接受 b -> r 類型的延續的函數;然後,試著構造出一個有著合適類型的實現。)

實現:

chainCPS s f = \k -> s $ \x -> f x $ k

我們提供給暫停的計算 s 一個延續,後者用 f 返回一個新的暫停的計算;這個暫停的計算隨即將頂層的延續 k 傳遞進去。意料之中的是,這個實現和上一個例子中的多層嵌套lambda表達式看起來很像。

chainCPS 的類型是不是很眼熟?如果我們將 (a -> r) -> r 代換成 (Monad m) => m a,將 (b -> r) -> r 代換成 (Monad m) => m b,我們就得到了 (>>=) 的類型簽名。我們再熟悉不過的 flip ($) 在這裡其到了類似於 return 的作用: 它接受一個值並返回一個相對應的暫停的計算。嘿,我們定義了一個Monad![3] 我們現在只需要把暫停的計算包裹進一個類似 Cont r a 的代理類型就好了。

cont :: ((a -> r) -> r) -> Cont r a
runCont :: Cont r a -> (a -> r) -> r

Cont 的 instance Monad 實現和我們剛剛所講的基本一致,雖然在包裹函數上有細微差別:

instance Monad (Cont r) where
    return x = cont ($ x)
    s >>= f  = cont $ \c -> runCont s $ \x -> runCont (f x) c

這使得我們不必顯式傳遞延續,也就不用手寫嵌套的lambda表達式了。{{{1}}} 將一個暫停的計算傳入一個CPS函數中。最後,我們用 runCont 來提取出最終結果。如下例:

Example
Example

例子: 使用了 Cont Monad 的 pythagoras

-- 使用在transformers库中定义的Cont Monad
import Control.Monad.Trans.Cont

add_cont :: Int -> Int -> Cont r Int
add_cont x y = return (add x y)

square_cont :: Int -> Cont r Int
square_cont x = return (square x)

pythagoras_cont :: Int -> Int -> Cont r Int
pythagoras_cont x y = do
    x_squared <- square_cont x
    y_squared <- square_cont y
    add_cont x_squared y_squared


callCC

[編輯]

雖然我們自然地構造出了一個Monad,但是你也許會疑惑,因為我們之前曾提到過CPS可以實現程序的控制流。在將程序轉換成CPS形式後,我們將延續包裹在了Monad中,這使我們失去了實現控制流的靈活性。於是我們引入了 callCC 函數,它能在,且僅在我們需要時賦予我們對延續的直接控制。


callCC 是一個非常特殊的函數,我們將通過例子來逐漸引入:

Example
Example

例子: 使用了 callCCsquare

-- 不使用 callCC
square :: Int -> Cont r Int
square n = return (n ^ 2)

-- 使用 callCC
squareCCC :: Int -> Cont r Int
squareCCC n = callCC $ \k -> k (n ^ 2)


我們傳遞給 callCC 一個函數作為參數,後者返回一個暫停的計算 (即返回值的類型為 Cont r a),我們將這個返回值稱為"callCC 計算"。原則上callCC 計算就是整個 callCC 的返回值。最關鍵的地方,也即 callCC 的獨特之處,在於 k,傳入 callCC 的函數的所接收的參數。這是一個能使整個計算跳出的 彈射按鈕: 任何地方對它的調用 k x 將把 x 封裝成一個暫停的計算,然後隨著控制流返回到 callCC 被調用的地方。這是一個無條件跳轉;特別的,在 k 被調用之處往後的計算將被捨棄。從另一種角度看,k 獲得了 callCC 之後 剩餘的計算;對它的調用將把一個值傳入這剩餘的計算當中("callCC" 是 "call with current continuation" 的縮寫,即 "調用並傳入當前的延續")。雖然在上例中,k 所起的作用和 return 並無二致,callCC 向我們展示了一些全新的可能性。

決定什麼時候使用 k

[編輯]

callCC 使我們獲得了決定何時將何值傳入延續的能力。下例將展示它的運用。

Example
Example

例子: 我們的第一個完全使用了 callCC 能力的函數

foo :: Int -> Cont r String
foo x = callCC $ \k -> do
    let y = x ^ 2 + 3
    when (y > 20) $ k "大于20"
    return (show $ y - 4)


foo 是一個稍稍自找麻煩的計算輸入參數的平方加三的函數;如果結果大於 20,那麼我們的 callCC 函數(在本例中,即整個 foo 函數)將立即返回,並將 "大于二十" 裝入一個暫停的計算中,後者隨即被傳遞給 foo。如果不是,我們將結果減去4,用 show 轉換成一個字符串,然後和上一種情況一樣封裝入一個暫停的計算中。值得注意的是,k 在這裡起的作用就像命令式語言中的 'return' 語句 一般,即立即退出並返回一個函數。但在Haskell中,k 只是一個語言的一等公民 —— 函數 —— 罷了。 因此你可以做一些將它傳遞給 when,將它保存在 Reader 中,諸如此類的事。

自然,你可以在 do 代碼塊中使用 callCC:

Example
Example

例子: 更成熟的使用了do 代碼塊的 callCC 例子

bar :: Char -> String -> Cont r Int
bar c s = do
    msg <- callCC $ \k -> do
        let s0 = c : s
        when (s0 == "你好") $ k "他们说你好呀"
        let s1 = show s0
        return ("他们似乎在说: " ++ s1)
    return (length msg)


當你用一個值調用 k 時,整個對 callCC 的調用都得到了這個值。從最終效果上看,這使得 k 看起來與其他語言中的 'goto' 語句類似: 在我們的例子中,當我們調用 k 時,它將計算跳轉到 callCC 被調用的地方,也就是 msg <- callCC $ ... 這一行。對傳遞給 callCC 的 do 代碼塊的執行到此為止。下例中,有一個不會被執行的語句:

Example
Example

例子: 跳出一個函數,其中一行將不被執行

quux :: Cont r Int
quux = callCC $ \k -> do
    let n = 5
    k n
    return 25


quux 將返回 5,而不是 25,因為我們在到達 return 25 這一行之前就跳出了 quux

幕後

[編輯]

我們故意沒有使用以往的風格: 通常當我們引入一個函數時,我們會先給出它的類型;但在這個函數上我們決定換一條路。原因很簡單: 它的類型相當複雜,而且我們並不能從中一眼看出它的功能,或者它的實現。然而,在剛剛對 callCC 的使用方式進行了演示後,我們認為你已經準備好了。放輕鬆...

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

有了關於 callCC 的一些認識,我們可以試著解讀這個類型簽名。整個函數的返回值和作參數的函數的返回值類型必須相同(即 Cont r a),因為如果我們不調用 k 那麼作為參數的函數的返回值將直接被傳入 callCC。那麼,k 的類型是什麼呢?如同我們之前所說的那樣,k 將它的參數裝入一個暫停的計算中,然後控制流返回到 callCC 被調用之處;因此,若後者的類型的為 Cont r ak 的參數必須具有 a 的類型。有趣的是,只要它具有 Cont r b 的形式,k 返回值的類型(也就是 b 的類型)並無關緊要。這是因為以 a 為參數產生的暫停的計算將接受 callCC 之後的延續,而不是 k 的調用之後的。

註解

由於 k 返回值類型的形式,下例將產生一個類型錯誤

quux :: Cont r Int
quux = callCC $ \k -> do
   let n = 5
   when True $ k n
   k 25

k 的返回值並不被局限於某一個類型;然而,when 將它限制成 Cont r (),因此 k 25 的類型和 quux 的不符。我們只需要將末尾的 kreturn 代替,或使用 k () 就可以了。

在這個部分的最後,我們來看看 callCC 的實現。你能從中找到 k 嗎?

callCC f = cont $ \h -> runCont (f (\a -> cont $ \_ -> h a)) h

這段代碼也許有些難以理解。出乎意料的,ContcallCCreturn(>>=) 的實現都能從他們的類型中自動推導。Lennart Augustsson 的 Djinn [1] 就是這樣的一個自動推導程序。同樣參見 Phil Gossett 的 Google tech talk: [2] 以了解 Djinn 的理論背景;另外,參見 Dan Piponi 的文章: [3],其中講述了如何用 Djinn 推導CPS。

例子: 一個複雜的控制結構

[編輯]

我們來看一些更為複雜的控制流操作的樣例。第一個節選自 "The Continuation monad" All about monads tutorial,已獲得授權。

Example
Example

例子: 在複雜的控制結構中使用Cont Monad

{- 我们用Cont Monad来"跳出"代码块。
以下函数实现了一个复杂的控制结构以处理数字:

输入 (n)     输出                      列表里的内容
=========     ======                    ==========
0-9           n                         无
10-199        (n/2) 的数位数            (n/2) 的数位数
200-19999     n                         (n/2) 的数位数
20000-1999999 (n/2) 倒过来              无
>= 2000000    (n/2) 的数位之和          (n/2) 的数位数
-} 
fun :: Int -> String
fun n = (`runCont` id) $ do
    str <- callCC $ \exit1 -> do                            -- 定义 "exit1"
        when (n < 10) (exit1 (show n))
        let ns = map digitToInt (show (n `div` 2))
        n' <- callCC $ \exit2 -> do                         -- 定义 "exit2"
            when ((length ns) < 3) (exit2 (length ns))
            when ((length ns) < 5) (exit2 n)
            when ((length ns) < 7) $ do
                let ns' = map intToDigit (reverse ns)
                exit1 (dropWhile (=='0') ns')               -- 跳出两层结构
            return $ sum ns
        return $ "(ns = " ++ (show ns) ++ ") " ++ (show n')
    return $ "Answer: " ++ str


fun 接收一個 n 作為參數。它的實現使用了 ContcallCC 以構建一個控制結構,其中再次使用 ContcallCC ,依據 n 的範圍來做一些不同的事,正如開頭的注釋所說那樣。讓我們一步步看:

  1. 首先,處於代碼首層的 (`runCont` id) 僅僅是說我們將構造一個 Cont 塊並使用 id 作為延續(換句話說,我們不做改變地將值從暫停的計算中提取出來)。這是有必要的,因為 fun 的返回值類型中並沒有 Cont
  2. 我們將以下 callCC do 代碼塊的結果命名為 str:
    1. n 小於10,我們直接退出,返回 n 的字符串表示。
    2. 否則,我們構造列表 ns,其中包含了 n `div` 2 的各數位。
    3. 將以下 callCC do 代碼塊的結果(類型為 Int)命名為n'
      1. 如果 length ns < 3 成立,也就是說,n `div` 2 的數位數小於3,我們從此層退出,返回數位數。
      2. 如果 n `div` 2 的數位數小於5,我們從此層退出,返回 n
      3. 如果 n `div` 2 的數位數小於7,我們從此層以及外層退出,返回 n `div` 2 的倒序顯示的數位數(一個String)。
      4. 否則,我們從此層退出,返回 n `div` 2 的各數位之和。
    4. 我們從這個 do 代碼塊退出,返回 String "(ns = X) Y",其中 X 代表 nsn `div` 2 的各數位,Y 代表從內層 do 代碼塊中返回的結果,即 n'
  3. 最後,我們從整個函數中返回 "Answer: Z",其中 Z 代表我們從 callCC do 代碼塊中獲得的值。

例子: 異常

[編輯]

我們也可以用延續來處理異常。我們將使用 兩個 延續: 一個用於處理異常,另一個代表執行成功後的後續計算。下面這個簡單的函數將它的兩個參數作整數除法,若分母為零則產生異常。

Example
Example

例子: 能夠拋出異常的 div

divExcpt :: Int -> Int -> (String -> Cont r Int) -> Cont r Int
divExcpt x y handler = callCC $ \ok -> do
    err <- callCC $ \notOk -> do
        when (y == 0) $ notOk "分母为零"
        ok $ x `div` y
    handler err

{- For example,
runCont (divExcpt 10 2 error) id --> 5
runCont (divExcpt 10 0 error) id --> *** Exception: 分母为零
-}


它是如何工作的?我們使用了兩個嵌套的 callCC 調用。第一個給出了當一切正常時使用的延續;第二個則給出了當異常發生時所使用的延續。如果分母不為 0,x `div` y 被傳入 ok,計算跳回頂層的 divExcpt。但是,如果分母為 0,我們將一條錯誤信息傳給 notOk,後者將我們從內層的 do 代碼塊中彈出。我們給出的信息被命名為 err 並傳給 handler

下面是一個更為通用的異常處理函數。傳入一個暫停中的計算作為第一個參數(更精確的說,這是一個接收一個錯誤處理函數然後返回一個Cont Monad的函數),以及一個錯誤處理函數作為第二個參數。本例使用了 MonadCont 類型類 [4],其包含了 Cont 和相對應的 ContT transformer,以及對應的一系列 instance。

Example
Example

例子: 通用的 try ,使用了延續。

import Control.Monad.Cont

tryCont :: MonadCont m => ((err -> m a) -> m a) -> (err -> m a) -> m a
tryCont c h = callCC $ \ok -> do
    err <- callCC $ \notOk -> do
        x <- c notOk
        ok x
    h err


實際使用的例子 try :

Example
Example

例子: 使用 try

data SqrtException = LessThanZero deriving (Show, Eq)

sqrtIO :: (SqrtException -> ContT r IO ()) -> ContT r IO ()
sqrtIO throw = do 
    ln <- lift (putStr "输入一个需要开平方根的数: " >> readLn)
    when (ln < 0) (throw LessThanZero)
    lift $ print (sqrt ln)

main = runContT (tryCont sqrtIO (lift . print)) return


本例中,拋出異常意味著從 callCC 的代碼塊中彈出。sqrtIO 中的 throw 使得我們從 tryCont 內的 callCC 中跳出.

例子: 協程

[編輯]

本例中,我們定義一個 CoroutineT Monad 以實現 forkyield 函數。fork 將一個暫停的協程壓入隊列,yield 暫停當前的協程.

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- 我们使用 GeneralizedNewtypeDeriving 以避免一些枯燥无味的代码。在GHC 7.8及之前的版本中,
-- 这个拓展满足 Safe Haskell。

import Control.Applicative
import Control.Monad.Cont
import Control.Monad.State

-- CoroutineT Monad 只是一个ContT嵌套的StateT,后者包含了暂停的协程。
newtype CoroutineT r m a = CoroutineT {runCoroutineT' :: ContT r (StateT [CoroutineT r m ()] m) a}
    deriving (Functor,Applicative,Monad,MonadCont,MonadIO)

-- 用以操作协程的队列。
getCCs :: Monad m => CoroutineT r m [CoroutineT r m ()]
getCCs = CoroutineT $ lift get

putCCs :: Monad m => [CoroutineT r m ()] -> CoroutineT r m ()
putCCs = CoroutineT . lift . put

-- 从队列弹出/向队列压入协程。
dequeue :: Monad m => CoroutineT r m ()
dequeue = do
    current_ccs <- getCCs
    case current_ccs of
        [] -> return ()
        (p:ps) -> do
            putCCs ps
            p

queue :: Monad m => CoroutineT r m () -> CoroutineT r m ()
queue p = do
    ccs <- getCCs
    putCCs (ccs++[p])

-- 接口。
yield :: Monad m => CoroutineT r m ()
yield = callCC $ \k -> do
    queue (k ())
    dequeue

fork :: Monad m => CoroutineT r m () -> CoroutineT r m ()
fork p = callCC $ \k -> do
    queue (k ())
    p
    dequeue

-- 恢复暂停的协程,直到队列为空。
exhaust :: Monad m => CoroutineT r m ()
exhaust = do
    exhausted <- null <$> getCCs
    if not exhausted
        then yield >> exhaust
        else return ()

-- 在上层的 Monad m 中运行协程。
runCoroutineT :: Monad m => CoroutineT r m r -> m r
runCoroutineT = flip evalStateT [] . flip runContT return . runCoroutineT' . (<* exhaust)

使用樣例:

printOne n = do
    liftIO (print n)
    yield

example = runCoroutineT $ do
    fork $ replicateM_ 3 (printOne 3)
    fork $ replicateM_ 4 (printOne 4)
    replicateM_ 2 (printOne 2)

輸出:

3
4
3
2
4
3
2
4
4

Template:Haskell/NotesSection

  1. \x -> ($ x),展開既得 \x -> \k -> k x
  2. attoparsec 是一個使用CPS以提高性能的例子。
  3. 練習: 驗證並證明它滿足Monad的性質。
  4. mtl 包,模塊 Template:Haskell lib.