跳转到内容

Haskell/延续过渡风格(CPS)

维基教科书,自由的教学读本

延续过渡风格 (简称CPS) 是一种函数不直接返回值的代码风格;在这种风格中,函数将结果传入一个 延续 (continuation,指“之后的内容”),后者决定了之后的逻辑。本章将探讨CPS能够如何应用在Haskell中,以及如何使用Monad表达CPS。

延续是什么?

[编辑]

回忆一下之前我们介绍 ($) 函数的时候:

> map ($ 2) [(2*), (4*), (8*)]
[4,8,16]

这段代码并没有什么值得注意之处,虽然显得有些古怪,因为我们并没有使用通常的写法 map (*2) [2, 4, 8]($) 函数使得代码看起来倒置了,好像我们实际上是在把参数传给函数,而不是将函数作用于参数上一样。而这种看起来意义不明的倒置却恰恰是CPS的核心! 从CPS的角度看来,($ 2) 是一个暂停中的计算: 也就是有着 (a -> r) -> r 类型的一个函数,后者接受另一个函数作为参数以产生最终结果。而这个具有 a -> r 类型的参数就是所谓的延续 (continuation);它指定了从当前函数到最终结果之间的逻辑。在上例中,列表中的函数被 map 作为延续,最终产生了三个不同的结果。值得注意的是,暂停中的计算和普通的值是可以相互转化的。函数 flip ($) 将一个值转化为暂停的计算 [1],而将 id 作为延续传入暂停的计算即可得回原值。

这有什么用呢?

[编辑]

CPS除了可以向新人炫技之外,还有更大的用途。它使得显式的操作和改变程序的控制流成为可能。比如说,像命令式语言一样,在一个函数全部执行完成之前返回一个值并跳出。异常和错误也能用CPS处理: 传入两个分别用于处理成功和失败状态的回调函数,并根据情况调用其中之一。我们还可以"暂停"一个计算,并在合适的时候使它继续;或者我们可以实现简单的并行计算(事实上,Hugs,一个Haskell解释器,使用CPS来实现并行)。

在Haskell中,我们可以以类似的方式使用CPS来在Monad中实现一些控制流。通常,我们也可以用别的方法来实现这些控制结构,特别是当我们应用了惰性计算的时候。在一些场景中(例如,当一个结构复杂的返回值最终将被不被调用者使用时),通过消除一些对类型构造函数的模式匹配,CPS能够在一定程度上改善代码的性能;但是,这种工作或许能够由一个足够智能的编译器代劳。[2]

传递(过渡)延续

[编辑]

我们可以修改我们的函数,使他们返回一个延续而不是一个普通的值。如下两例。

pythagoras

[编辑]
Example
Example

例子: 一些普通而简单的函数

-- 定义加法和平方函数:

add :: Int -> Int -> Int
add x y = x + y

square :: Int -> Int
square x = x * x

pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)


变幻成CPS后,pythagoras返回一个暂停中的计算:

Example
Example

例子: 一些简单的CPS函数

-- 定义使用了CPS的加法和平方函数,
-- (add_cps 和 square_cps 实际上并不是严格的CPS函数
-- 他们只是类型正确罢了)

add_cps :: Int -> Int -> ((Int -> r) -> r)
add_cps x y = \k -> k (add x y)

square_cps :: Int -> ((Int -> r) -> r)
square_cps x = \k -> k (square x)

pythagoras_cps :: Int -> Int -> ((Int -> r) -> r)
pythagoras_cps x y = \k ->
 square_cps x $ \x_squared ->
 square_cps y $ \y_squared ->
 add_cps x_squared y_squared $ k


我们来看看 pythagoras_cps 是如何运作的:

  1. x 平方并将结果传入延续 (\x_squared -> ...)
  2. y 平方并将结果传入延续 (\y_squared -> ...)
  3. x_squaredy_squared 平方并将结果传入顶层的延续中

我们可以在GHCi中实验这段代码,使用print函数作为延续:

*Main> pythagoras_cps 3 4 print
25

如果我们忽略 pythagoras_cps 类型中 (Int -> r) -> r 周围的括号,并将其与 pythagoras 的类型对比,我们可以发现,延续实际上只是一个被柯里化的额外参数,这也是为什么说我们往函数里传递(过渡)了一个延续。

thrice

[编辑]
Example
Example

例子: 一个简单的不使用CPS的高阶函数

thrice :: (a -> a) -> a -> a
thrice f x = f (f (f x))


*Main> thrice tail "foobar"
"bar"

一个像thrice这样的高阶函数的CPS形式,与它的原本形式不同,接受的参数也是CPS形式的函数。因此,f :: a -> a 将变幻成 f_cps :: a -> ((a -> r) -> r);在这个例子中,则是 thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r)。我们可以用函数的类型作为实现的指引 - 我们将f 换成相对应的CPS形式的函数,并将这些延续一路传递下去。

Example
Example

例子: 一个简单的使用CPS的高阶函数

thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r)
thrice_cps f_cps x = \k ->
 f_cps x $ \fx ->
 f_cps fx $ \ffx ->
 f_cps ffx $ k


Cont monad

[编辑]

我们需要一种复合CPS函数的方法,最好能够避免像刚刚那样使用多层嵌套的lambda表达式。我们可以从一个将一个CPS函数应用到一个暂停的计算(即CPS版本的值)上的类似于 ($) 的函数开始。我们试试看能不能写出他的类型:

chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)

(请读者试一试自己实现这个函数。提示: 这个函数返回一个接受 b -> r 类型的延续的函数;然后,试着构造出一个有着合适类型的实现。)

实现:

chainCPS s f = \k -> s $ \x -> f x $ k

我们提供给暂停的计算 s 一个延续,后者用 f 返回一个新的暂停的计算;这个暂停的计算随即将顶层的延续 k 传递进去。意料之中的是,这个实现和上一个例子中的多层嵌套lambda表达式看起来很像。

chainCPS 的类型是不是很眼熟?如果我们将 (a -> r) -> r 代换成 (Monad m) => m a,将 (b -> r) -> r 代换成 (Monad m) => m b,我们就得到了 (>>=) 的类型签名。我们再熟悉不过的 flip ($) 在这里其到了类似于 return 的作用: 它接受一个值并返回一个相对应的暂停的计算。嘿,我们定义了一个Monad![3] 我们现在只需要把暂停的计算包裹进一个类似 Cont r a 的代理类型就好了。

cont :: ((a -> r) -> r) -> Cont r a
runCont :: Cont r a -> (a -> r) -> r

Cont 的 instance Monad 实现和我们刚刚所讲的基本一致,虽然在包裹函数上有细微差别:

instance Monad (Cont r) where
    return x = cont ($ x)
    s >>= f  = cont $ \c -> runCont s $ \x -> runCont (f x) c

这使得我们不必显式传递延续,也就不用手写嵌套的lambda表达式了。{{{1}}} 将一个暂停的计算传入一个CPS函数中。最后,我们用 runCont 来提取出最终结果。如下例:

Example
Example

例子: 使用了 Cont Monad 的 pythagoras

-- 使用在transformers库中定义的Cont Monad
import Control.Monad.Trans.Cont

add_cont :: Int -> Int -> Cont r Int
add_cont x y = return (add x y)

square_cont :: Int -> Cont r Int
square_cont x = return (square x)

pythagoras_cont :: Int -> Int -> Cont r Int
pythagoras_cont x y = do
    x_squared <- square_cont x
    y_squared <- square_cont y
    add_cont x_squared y_squared


callCC

[编辑]

虽然我们自然地构造出了一个Monad,但是你也许会疑惑,因为我们之前曾提到过CPS可以实现程序的控制流。在将程序转换成CPS形式后,我们将延续包裹在了Monad中,这使我们失去了实现控制流的灵活性。于是我们引入了 callCC 函数,它能在,且仅在我们需要时赋予我们对延续的直接控制。


callCC 是一个非常特殊的函数,我们将通过例子来逐渐引入:

Example
Example

例子: 使用了 callCCsquare

-- 不使用 callCC
square :: Int -> Cont r Int
square n = return (n ^ 2)

-- 使用 callCC
squareCCC :: Int -> Cont r Int
squareCCC n = callCC $ \k -> k (n ^ 2)


我们传递给 callCC 一个函数作为参数,后者返回一个暂停的计算 (即返回值的类型为 Cont r a),我们将这个返回值称为"callCC 计算"。原则上callCC 计算就是整个 callCC 的返回值。最关键的地方,也即 callCC 的独特之处,在于 k,传入 callCC 的函数的所接收的参数。这是一个能使整个计算跳出的 弹射按钮: 任何地方对它的调用 k x 将把 x 封装成一个暂停的计算,然后随着控制流返回到 callCC 被调用的地方。这是一个无条件跳转;特别的,在 k 被调用之处往后的计算将被舍弃。从另一种角度看,k 获得了 callCC 之后 剩余的计算;对它的调用将把一个值传入这剩余的计算当中("callCC" 是 "call with current continuation" 的缩写,即 "调用并传入当前的延续")。虽然在上例中,k 所起的作用和 return 并无二致,callCC 向我们展示了一些全新的可能性。

决定什么时候使用 k

[编辑]

callCC 使我们获得了决定何时将何值传入延续的能力。下例将展示它的运用。

Example
Example

例子: 我们的第一个完全使用了 callCC 能力的函数

foo :: Int -> Cont r String
foo x = callCC $ \k -> do
    let y = x ^ 2 + 3
    when (y > 20) $ k "大于20"
    return (show $ y - 4)


foo 是一个稍稍自找麻烦的计算输入参数的平方加三的函数;如果结果大于 20,那么我们的 callCC 函数(在本例中,即整个 foo 函数)将立即返回,并将 "大于二十" 装入一个暂停的计算中,后者随即被传递给 foo。如果不是,我们将结果减去4,用 show 转换成一个字符串,然后和上一种情况一样封装入一个暂停的计算中。值得注意的是,k 在这里起的作用就像命令式语言中的 'return' 语句 一般,即立即退出并返回一个函数。但在Haskell中,k 只是一个语言的一等公民 —— 函数 —— 罢了。 因此你可以做一些将它传递给 when,将它保存在 Reader 中,诸如此类的事。

自然,你可以在 do 代码块中使用 callCC:

Example
Example

例子: 更成熟的使用了do 代码块的 callCC 例子

bar :: Char -> String -> Cont r Int
bar c s = do
    msg <- callCC $ \k -> do
        let s0 = c : s
        when (s0 == "你好") $ k "他们说你好呀"
        let s1 = show s0
        return ("他们似乎在说: " ++ s1)
    return (length msg)


当你用一个值调用 k 时,整个对 callCC 的调用都得到了这个值。从最终效果上看,这使得 k 看起来与其他语言中的 'goto' 语句类似: 在我们的例子中,当我们调用 k 时,它将计算跳转到 callCC 被调用的地方,也就是 msg <- callCC $ ... 这一行。对传递给 callCC 的 do 代码块的执行到此为止。下例中,有一个不会被执行的语句:

Example
Example

例子: 跳出一个函数,其中一行将不被执行

quux :: Cont r Int
quux = callCC $ \k -> do
    let n = 5
    k n
    return 25


quux 将返回 5,而不是 25,因为我们在到达 return 25 这一行之前就跳出了 quux

幕后

[编辑]

我们故意没有使用以往的风格: 通常当我们引入一个函数时,我们会先给出它的类型;但在这个函数上我们决定换一条路。原因很简单: 它的类型相当复杂,而且我们并不能从中一眼看出它的功能,或者它的实现。然而,在刚刚对 callCC 的使用方式进行了演示后,我们认为你已经准备好了。放轻松...

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

有了关于 callCC 的一些认识,我们可以试着解读这个类型签名。整个函数的返回值和作参数的函数的返回值类型必须相同(即 Cont r a),因为如果我们不调用 k 那么作为参数的函数的返回值将直接被传入 callCC。那么,k 的类型是什么呢?如同我们之前所说的那样,k 将它的参数装入一个暂停的计算中,然后控制流返回到 callCC 被调用之处;因此,若后者的类型的为 Cont r ak 的参数必须具有 a 的类型。有趣的是,只要它具有 Cont r b 的形式,k 返回值的类型(也就是 b 的类型)并无关紧要。这是因为以 a 为参数产生的暂停的计算将接受 callCC 之后的延续,而不是 k 的调用之后的。

注解

由于 k 返回值类型的形式,下例将产生一个类型错误

quux :: Cont r Int
quux = callCC $ \k -> do
   let n = 5
   when True $ k n
   k 25

k 的返回值并不被局限于某一个类型;然而,when 将它限制成 Cont r (),因此 k 25 的类型和 quux 的不符。我们只需要将末尾的 kreturn 代替,或使用 k () 就可以了。

在这个部分的最后,我们来看看 callCC 的实现。你能从中找到 k 吗?

callCC f = cont $ \h -> runCont (f (\a -> cont $ \_ -> h a)) h

这段代码也许有些难以理解。出乎意料的,ContcallCCreturn(>>=) 的实现都能从他们的类型中自动推导。Lennart Augustsson 的 Djinn [1] 就是这样的一个自动推导程序。同样参见 Phil Gossett 的 Google tech talk: [2] 以了解 Djinn 的理论背景;另外,参见 Dan Piponi 的文章: [3],其中讲述了如何用 Djinn 推导CPS。

例子: 一个复杂的控制结构

[编辑]

我们来看一些更为复杂的控制流操作的样例。第一个节选自 "The Continuation monad" All about monads tutorial,已获得授权。

Example
Example

例子: 在复杂的控制结构中使用Cont Monad

{- 我们用Cont Monad来"跳出"代码块。
以下函数实现了一个复杂的控制结构以处理数字:

输入 (n)     输出                      列表里的内容
=========     ======                    ==========
0-9           n                         无
10-199        (n/2) 的数位数            (n/2) 的数位数
200-19999     n                         (n/2) 的数位数
20000-1999999 (n/2) 倒过来              无
>= 2000000    (n/2) 的数位之和          (n/2) 的数位数
-} 
fun :: Int -> String
fun n = (`runCont` id) $ do
    str <- callCC $ \exit1 -> do                            -- 定义 "exit1"
        when (n < 10) (exit1 (show n))
        let ns = map digitToInt (show (n `div` 2))
        n' <- callCC $ \exit2 -> do                         -- 定义 "exit2"
            when ((length ns) < 3) (exit2 (length ns))
            when ((length ns) < 5) (exit2 n)
            when ((length ns) < 7) $ do
                let ns' = map intToDigit (reverse ns)
                exit1 (dropWhile (=='0') ns')               -- 跳出两层结构
            return $ sum ns
        return $ "(ns = " ++ (show ns) ++ ") " ++ (show n')
    return $ "Answer: " ++ str


fun 接收一个 n 作为参数。它的实现使用了 ContcallCC 以构建一个控制结构,其中再次使用 ContcallCC ,依据 n 的范围来做一些不同的事,正如开头的注释所说那样。让我们一步步看:

  1. 首先,处于代码首层的 (`runCont` id) 仅仅是说我们将构造一个 Cont 块并使用 id 作为延续(换句话说,我们不做改变地将值从暂停的计算中提取出来)。这是有必要的,因为 fun 的返回值类型中并没有 Cont
  2. 我们将以下 callCC do 代码块的结果命名为 str:
    1. n 小于10,我们直接退出,返回 n 的字符串表示。
    2. 否则,我们构造列表 ns,其中包含了 n `div` 2 的各数位。
    3. 将以下 callCC do 代码块的结果(类型为 Int)命名为n'
      1. 如果 length ns < 3 成立,也就是说,n `div` 2 的数位数小于3,我们从此层退出,返回数位数。
      2. 如果 n `div` 2 的数位数小于5,我们从此层退出,返回 n
      3. 如果 n `div` 2 的数位数小于7,我们从此层以及外层退出,返回 n `div` 2 的倒序显示的数位数(一个String)。
      4. 否则,我们从此层退出,返回 n `div` 2 的各数位之和。
    4. 我们从这个 do 代码块退出,返回 String "(ns = X) Y",其中 X 代表 nsn `div` 2 的各数位,Y 代表从内层 do 代码块中返回的结果,即 n'
  3. 最后,我们从整个函数中返回 "Answer: Z",其中 Z 代表我们从 callCC do 代码块中获得的值。

例子: 异常

[编辑]

我们也可以用延续来处理异常。我们将使用 两个 延续: 一个用于处理异常,另一个代表执行成功后的后续计算。下面这个简单的函数将它的两个参数作整数除法,若分母为零则产生异常。

Example
Example

例子: 能够抛出异常的 div

divExcpt :: Int -> Int -> (String -> Cont r Int) -> Cont r Int
divExcpt x y handler = callCC $ \ok -> do
    err <- callCC $ \notOk -> do
        when (y == 0) $ notOk "分母为零"
        ok $ x `div` y
    handler err

{- For example,
runCont (divExcpt 10 2 error) id --> 5
runCont (divExcpt 10 0 error) id --> *** Exception: 分母为零
-}


它是如何工作的?我们使用了两个嵌套的 callCC 调用。第一个给出了当一切正常时使用的延续;第二个则给出了当异常发生时所使用的延续。如果分母不为 0,x `div` y 被传入 ok,计算跳回顶层的 divExcpt。但是,如果分母为 0,我们将一条错误信息传给 notOk,后者将我们从内层的 do 代码块中弹出。我们给出的信息被命名为 err 并传给 handler

下面是一个更为通用的异常处理函数。传入一个暂停中的计算作为第一个参数(更精确的说,这是一个接收一个错误处理函数然后返回一个Cont Monad的函数),以及一个错误处理函数作为第二个参数。本例使用了 MonadCont 类型类 [4],其包含了 Cont 和相对应的 ContT transformer,以及对应的一系列 instance。

Example
Example

例子: 通用的 try ,使用了延续。

import Control.Monad.Cont

tryCont :: MonadCont m => ((err -> m a) -> m a) -> (err -> m a) -> m a
tryCont c h = callCC $ \ok -> do
    err <- callCC $ \notOk -> do
        x <- c notOk
        ok x
    h err


实际使用的例子 try :

Example
Example

例子: 使用 try

data SqrtException = LessThanZero deriving (Show, Eq)

sqrtIO :: (SqrtException -> ContT r IO ()) -> ContT r IO ()
sqrtIO throw = do 
    ln <- lift (putStr "输入一个需要开平方根的数: " >> readLn)
    when (ln < 0) (throw LessThanZero)
    lift $ print (sqrt ln)

main = runContT (tryCont sqrtIO (lift . print)) return


本例中,抛出异常意味着从 callCC 的代码块中弹出。sqrtIO 中的 throw 使得我们从 tryCont 内的 callCC 中跳出.

例子: 协程

[编辑]

本例中,我们定义一个 CoroutineT Monad 以实现 forkyield 函数。fork 将一个暂停的协程压入队列,yield 暂停当前的协程.

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- 我们使用 GeneralizedNewtypeDeriving 以避免一些枯燥无味的代码。在GHC 7.8及之前的版本中,
-- 这个拓展满足 Safe Haskell。

import Control.Applicative
import Control.Monad.Cont
import Control.Monad.State

-- CoroutineT Monad 只是一个ContT嵌套的StateT,后者包含了暂停的协程。
newtype CoroutineT r m a = CoroutineT {runCoroutineT' :: ContT r (StateT [CoroutineT r m ()] m) a}
    deriving (Functor,Applicative,Monad,MonadCont,MonadIO)

-- 用以操作协程的队列。
getCCs :: Monad m => CoroutineT r m [CoroutineT r m ()]
getCCs = CoroutineT $ lift get

putCCs :: Monad m => [CoroutineT r m ()] -> CoroutineT r m ()
putCCs = CoroutineT . lift . put

-- 从队列弹出/向队列压入协程。
dequeue :: Monad m => CoroutineT r m ()
dequeue = do
    current_ccs <- getCCs
    case current_ccs of
        [] -> return ()
        (p:ps) -> do
            putCCs ps
            p

queue :: Monad m => CoroutineT r m () -> CoroutineT r m ()
queue p = do
    ccs <- getCCs
    putCCs (ccs++[p])

-- 接口。
yield :: Monad m => CoroutineT r m ()
yield = callCC $ \k -> do
    queue (k ())
    dequeue

fork :: Monad m => CoroutineT r m () -> CoroutineT r m ()
fork p = callCC $ \k -> do
    queue (k ())
    p
    dequeue

-- 恢复暂停的协程,直到队列为空。
exhaust :: Monad m => CoroutineT r m ()
exhaust = do
    exhausted <- null <$> getCCs
    if not exhausted
        then yield >> exhaust
        else return ()

-- 在上层的 Monad m 中运行协程。
runCoroutineT :: Monad m => CoroutineT r m r -> m r
runCoroutineT = flip evalStateT [] . flip runContT return . runCoroutineT' . (<* exhaust)

使用样例:

printOne n = do
    liftIO (print n)
    yield

example = runCoroutineT $ do
    fork $ replicateM_ 3 (printOne 3)
    fork $ replicateM_ 4 (printOne 4)
    replicateM_ 2 (printOne 2)

输出:

3
4
3
2
4
3
2
4
4

Template:Haskell/NotesSection

  1. \x -> ($ x),展开既得 \x -> \k -> k x
  2. attoparsec 是一个使用CPS以提高性能的例子。
  3. 练习: 验证并证明它满足Monad的性质。
  4. mtl 包,模块 Template:Haskell lib.