Post

Monad 5 - Continuation Passing Style

Monad 5 - Continuation Passing Style

Before starting

3편을 통해 우리는 모나드를 좀 더 자유롭게 사용하기 위한 다양한 변환들에 대해 알아봤고, 그 중 일부 내용의 이론적인 고찰은 4편에서 정리했다. 이번 글에서는 주제를 조금 바꿔서, 함수형 언어에서 값 대신 “Continuation”을 넘기는 방식에 대해 알아볼 것이다.

Continuation

먼저 일반적인 함수 호출 구조를 생각해보자.

1
2
3
4
5
add :: Int -> Int -> Int
add x y = x + y

result :: Int
result = add (add 1 2) (add 3 4)

이 코드를 실행할 때 다음의 과정들이 일어날 것이다.

  1. add 1 2 계산 -> 3 리턴
  2. add 3 4 계산 -> 7 리턴
  3. add 3 7 계산 -> 10 리턴
  4. 10result에 바인딩

이 과정에서 add 1 2의 결과인 3을 기억해서, 이걸 바깥쪽의 add의 호출에 사용해야 한다는 정보를 Call Stack이 기억하고 있다. 그런데 기억을 한다는 것은 어떤 식으로든 명시적인 형태로 갖고 있다는 것인데, 그럼 그걸 노출할 수도 있지 않을까?

이걸 Haskell에서는 Continuation이라는 형태로 표현한다.

1
2
add 1 2
addCPS 1 2 (\result -> result)

addCPS의 마지막 인자에 “결과를 받았을 때 수행할 함수”가 들어간다.

Continuation에 결과를 전달하기만 하면 되니 addCPS를 구현하는 것은 간단하다.

1
2
3
4
5
add :: Int -> Int -> Int
add x y = x + y

addCPS :: Int -> Int -> (Int -> r) -> r
addCPS x y k = k (x + y)

타입 선언을 봐도 (Int -> r)을 인자로 받고, 그 결과로 나온 r이 최종적인 리턴 타입이 됨을 알 수 있다. 아까의 result와 같이 복잡한 타입을 Continuation으로 바꾸면 이렇게 표현할 수 있다.

1
2
3
4
5
6
7
8
9
nested :: Int
nested = add (add 1 2) (add 3 4)

nestedCPS :: (Int -> r) -> r
nestedCPS k =
    addCPS 1 2 $ \r1 ->     -- add 1 2
    addCPS 3 4 $ \r2 ->     -- add 3 4
    addCPS r1 r2 $ \r3 ->   -- add r1 r2
    k r3                    -- result

즉, Call Stack이 하던 일을 그대로 lambda로 표현한 것이다.

Cont Monad

결국 Continuation이 하는 일을 보면 (a -> r) -> r이 반복된다. 그러면 이를 타입으로 만들 수 있지 않을까?

1
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

그리고 이렇게 만들어진 Cont는 아래와 같이 모나드가 된다.

1
2
3
4
5
6
7
8
9
10
11
12
13
instance Functor (Cont r) where
    fmap f (Cont c) = Cont $ \k -> c (k . f)

instance Applicative (Cont r) where
    pure a = Cont $ \k -> k a
    (Cont f) <*> (Cont x) = Cont $ \k -> f $ \g -> x $ \a -> k (g a)

instance Monad (Cont r) where
    return = pure

    (Cont c) >>= f = Cont $ \k ->
        c $ \a ->           -- c를 실행해서 a를 얻음
        runCont (f a) k     -- f a를 k에 전달

그러면 아래와 같이 Cont 모나드를 이용해서 좀 더 간편하게 쓸 수 있다.

1
2
3
4
5
6
7
computation :: Cont r Int
computation = do
    x <- return 3
    y <- return 4
    return (x + y)

runCont computation id      -- 7

callCC

여기서 Cont 모나드를 끝내면 이건 그냥 “계산”을 추상화한 모나드에 지나지 않는다. 여기에 여러 가지 개념들이 추가되어야 비로소 Cont 모나드를 강력하게 쓸 수 있다.

우선 callCC 함수를 살펴보자. 이 함수의 이름은 “call with Current Continuation”에서 따왔다.

1
2
callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a
callCC f = Cont $ \k -> runCont (f (\a -> Cont $ \_ -> k a)) k

그런데 타입부터 직관적으로 이해하기 상당히 어렵다. 하나씩 분해해보자.

callCC는 인자로 f 하나만을 받는다. 이 f의 형태는 (a -> Cont r b) -> Cont r a인데, 보다시피 f 또한 인자로 a -> Cont r b 타입의 함수 k를 받고 있다.

그런데 k의 타입이 뭔가 이상하다. k에 대해서만 생각해보면, 인자로 받은 a와 리턴한 Cont r b는 전혀 다른 타입이다. 이 사라진 정보 af의 리턴값에서 Cont r a의 형태로 존재하는데, 이 말은 곧 k를 호출하면 그 계산이 통째로 건너뛰어진다는 의미와 동일하다. 이러한 기조는 callCC의 리턴값 Cont r a까지도 유지된다.

이제 함수의 구현도 파악해보자. 가장 안쪽에 존재하는 a -> Cont $ \_ -> k a부터 생각해보면, 나중에 들어올 Continuation을 무시하고 k aCont 모나드에 캡쳐한다. 그리고 그 결과를 runCont를 통해 즉시 계산하게 만든다. 즉, 이 함수가 “탈출” 역할을 하게 된다.

그러니까 정리하자면, callCC 내부에서 언제든 k를 호출하는 순간, 내부 구문이 얼마나 남았든지에 관계 없이 callCC 함수는 종료되고 마지막에 계산한 k의 결과만이 리턴되는 것이다.

shift, reset

callCC는 Undelimited Continuation이라고 부른다. 이는 경계가 없다는 뜻으로, 그 말 그대로 k가 호출되면 그 구문의 끝까지 점프해버릴 뿐이지 “어디까지” 점프해야하는지는 딱히 정해져있지 않다. callCC만으로도 충분히 강력하지만, 무조건 나머지 문맥을 전부 스킵한다는 점이 좀 불편하게 작용할 때가 있다.

이를 보완하기 위해 나온 것이 Delimited Continuation이다. 이번엔 “어디까지”라는 경계가 필요하므로, 아래의 두 함수가 필요하다.

1
2
3
4
5
6
7
8
9
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }

reset :: Monad m => ContT r m r -> ContT s m r
reset c = ContT $ \k ->
    runContT c return >>= k

shift :: Monad m => ((a -> m r) -> ContT r m r) -> ContT r m a
shift f = ContT $ \k ->
    runContT (f k) return

우선 경계에 대한 타입 r이 명시적으로 필요하다. reset이 그 경계값을 설정하는 역할인데, return을 Continuation 안에 넣고 있음을 볼 수 있다. 모나드에서의 return의 역할을 생각해보자. 모나드 내부의 계산이 return을 만나면 그냥 모나드 내부의 값을 꺼내고 종료한다. 그래서 reset이 걸린 지점에서 계산이 멈추는 것이다.

shiftcallCC의 역할을 하는데, 여기서는 callCC에선 이후 Continuation을 아예 버리는 것과는 달리, kf에게 그대로 넘긴다. 이렇기 때문에 f 안에서 k를 여러번 호출할 수 있다.

1
2
3
4
5
6
7
8
9
result :: Int
result = runIdentity $ runContT
    (reset $ do
        x <- shift $ \k -> do
            a <- k 1
            b <- k 2
            return (a + b)
        return (x * 10))
    return

callCC와는 달리 경계를 설정한다는 개념이 잘 이해가 가지 않을 수 있으니, 위의 간단한 예시를 확인해보자. 먼저 reset으로 인해 경계가 설정되었다. 이제 runContT로 전체가 둘러싸여 있지만 점프를 하면 reset의 마지막 구문까지만 이동하게 된다. shift 내부의 k를 호출하면, 일시적으로 흐름을 뛰어넘어 위에서 설정한 reset의 경계선으로 이동한다. 그러면 a에는 10이 들어가게 될 것이다. 그리고 다음 문장을 실행하게 되고, 여기서도 reset의 경계선으로 이동하니 b에는 20이 들어가게 된다. 그렇기 때문에 결과적으로 30이 출력될 것이다.

Some Examples

그럼 이제 ContcallCC를 활용한 패턴을 몇 개 살펴보자. 이러한 다양한 패턴을 모두 같은 방식으로 구현이 가능한 것이 CPS의 독보적인 장점이라고 생각한다.

Early return

1
2
3
4
5
6
7
8
productList :: [Int] -> Int
productList xs = runCont (callCC $ \exit -> do
    forM_ xs $ \x -> do
        when (x == 0) $ exit 0
    return (product xs)) id

productList [1, 2, 3, 4]    -- 24
productList [1, 2, 0, 4]    -- 0

리스트를 순회하는 중 0이 나오면 즉시 0을 리턴하고 종료하도록 하여 곱셈이 실행되지 않게 한다.

Exception

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
safeDiv :: Int -> Int -> Cont r (Either String Int)
safeDiv x y = callCC $ \throw -> do
    when (y == 0) $ throw (Left "Division by zero")
    return $ Right (x `div` y)

compute :: Int -> Int -> Int -> Cont r (Either String Int)
compute a b c = callCC $ \throw -> do
    x <- callCC $ \innerThrow -> do
        when (a < 0) $ throw (Left "Negative divided")
        when (b == 0) $ innerThrow (Left "First division failed")
        return $ Right (a `div` b)
    y <- callCC $ \innerThrow -> do
        when (a < 0) $ throw (Left "Negative divided")
        when (c == 0) $ innerThrow (Left "Second division failed")
        return $ Right (a `div` c)
    return $ (+) <$> x <*> y

위의 compute 함수처럼 callCC 자체는 중복해서 사용할 수 있으며, 각각의 k는 독립적인 탈출 범위를 갖는다. 물론 각 블록 내에서 처음으로 수행된 k에 대해서만 즉시 종료하고 그 값을 리턴한다.

Break loop

1
2
3
4
5
6
7
8
loop :: Cont r [Int]
loop = callCC $ \break -> do
    result <- forM [1..10] $ \i -> do
        when (i == 5) $ break []
        return i
    return result

runCont loop id

루프를 돌다가 멈춰야 할 때 k를 호출하면 된다.

Continue loop

1
2
3
4
5
6
7
8
9
loop :: Cont r [Int]
loop = do
    results <- forM [1..10] $ \i ->
        callCC $ \continue -> do
            when (i == 5) $ continue Nothing
            return (Just i)
    return $ catMaybes results
    
runCont loop id

break와 다르게 전체 루프 안에 callCC가 들어간 형태로 구현하면 된다. 다만, k는 대신 리턴해줄 값이 반드시 필요하므로 정말로 건너뛰기를 구현하고 싶다면 위의 예제처럼 Maybe 모나드를 같이 쓰거나 해야 한다.

혹은 Haskell에서 제공하는 runContTlift를 이용하면 보다 간편하게 구현할 수 있다. 이름에서 짐작할 수 있듯이 Cont 모나드에 대한 Transformer이다.

1
2
3
4
5
6
7
8
9
10
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Class

loop :: IO ()
loop = runContT (
    forM_ [1..10] $ \i ->
        callCC $ \skip -> do
            when (i == 5) $ skip ()
            lift $ print i
    ) return

물론 이 경우는 IO 모나드로 리턴이 나온다. 특정 모나드로의 변환이 아니라 순수하게 값을 모으고 싶다면 첫번째 예시처럼 써야한다.

Tail Recursion

흔히들 알고 있는 Tail Recursion은 재귀 함수의 마지막 동작이 재귀 호출인 경우를 의미한다. 이 경우엔 현재 스택을 재사용할 수 있어 스택이 쌓이지 않아야 한다.

그러나 아래와 같이 그냥 재귀함수를 호출하게 되면 이러한 최적화가 이루어지지 않아 스택이 계속 쌓이게 된다.

1
2
3
factorial :: Int -> Int
factorial 0 = 1
factorial n = n * factorial (n - 1)

이를 해결하기 위해서는 다양한 방법이 있지만 대부분 아래와 같이 Accumulator를 사용하는 방법을 먼저 배웠을 것이다.

1
2
3
4
5
6
factorialTail :: Int -> Int -> Int
factorialTail 0 acc = acc
factorialTail n acc = factorialTail (n - 1) (n * acc)

factorial :: Int -> Int
factorial n = factorialTail n 1

그런데 이 CPS로도 Tail Recursion을 최적화할 수 있다.

1
2
3
factorial :: Int -> (Int -> r) -> r
factorial 0 k = k 1
factorial n k = factorial (n - 1) (\r -> k (n * r))

이렇게 구현하면 “이후에 할 일”이 이미 Continuation k에 들어있는 상태라 항상 Tail Recursion 최적화가 이루어진다.

Coroutine

몇몇 언어에는 Coroutine이라는 개념이 있다. 함수의 실행을 일시정지하고 필요시 다시 재개할 수 있게 한다. 이는 C#에 yield, Go의 go (이쪽은 고루틴이라고 따로 부른다), 그리고 대부분의 언어가 지원하는 async, await 등의 키워드가 이를 구현한 것이다. 이러한 패턴들 또한 CPS로 구현이 가능하다.

먼저 yield와 같은 “일시정지” 기능을 어떻게 구현할 수 있는지 확인해보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Class

data Step a b
    = Yield a (b -> Coroutine a b)
    | Done

newtype Coroutine a b = Coroutine { resume :: ContT (Step a b) IO b }

yield :: a -> ContT (Step a b) IO ()
yield value = ContT $ \k ->
    return $ Yield value k  -- "yield 이후의 나머지 계산"을 저장

-- Generator
counter :: ContT (Step Int b) IO ()
counter = do
    yield 1
    lift $ putStrLn "Between 1 and 2"
    yield 2
    lift $ putStrLn "Between 2 and 3"
    yield 3

-- Executors
runStep :: ContT (Step a b) IO b -> IO (Step a b)
runStep c = runContT c (\_ -> return Done)

collect :: ContT (Step a ()) IO () -> IO [a]
collect c = do
    step <- runStep c
    case step of
        Done        -> return []
        Yield v k   -> do
            rest <- collect (k ())
            return (v : rest)

foreach :: ContT (Step a ()) IO () -> (a -> IO ()) -> IO ()
foreach c handler = do
    step <- runStep c
    case step of
        Done        -> return ()
        Yield v k   -> do
            handler v
            foreach (k ()) handler

-- main
main :: IO ()
main = do
    values <- collect counter
    print values

    foreach counter $ \v ->
        putStrLn $ "Value: " ++ show v

이제 async/await를 어떻게 구현할 수 있는지 살펴보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
type Callback a = a -> IO ()
type Async a = Callback a -> IO ()

-- Sample async work
fetch :: Int -> String -> Async String
fetch delay url callback = do
    delayThread (delay * 1000000)
    callback $ "Response from " ++ url

data AsyncStep a
    = Pending (IO ())
    | Complete a

newtype AsyncT a = AsyncT { runAsyncT :: ContT (AsyncStep a) IO a }

await :: Async a -> AsyncT a
await asyncOp = AsyncT $ ContT $ \k ->
    return $ Pending $ asyncOp (runContT (k undefined) (\r -> do
        step <- runContT (k r) (\a -> return (Complete a))
        case step of
            Pending io -> io
            Complete _ -> return ()))


-- Callback Style
type AsyncCPS a = (a -> IO ()) -> IO ()

awaitCPS :: AsyncCPS a -> ContT r IO a
awaitCPS asyncOp = ContT $ \k -> asyncOp k  -- 결과를 받으면 k를 호출하여 연산 재개

결국 yieldasync/await나 내부 구현은 비슷비슷하다. 차이점이라면 yield는 Continuation을 Yield에 저장하고, await는 이걸 전달해준다는 것이다. 여담으로 C# 등 async/await를 지원하는 언어의 경우 컴파일러가 내부적으로 State Machine을 생성하는데, 그 State Machine이 위의 구현과 동일한 구조로 동작한다.

Nondeterministic

이젠 reset, shift를 사용한 예시를 보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
type NonDet a = ContT [a] Identity a

choose :: [a] -> NonDet a
choose xs = shift $ \k ->
    lift $ fmap concat $ mapM k xs

runNonDet :: NonDet a -> [a]
runNonDet m = runIdentity $ runContT
    (reset $ fmap (:[]) m)
    return

pairs :: [(Int, Int)]
pairs = runNonDet $ do
    x <- choose [1, 2, 3]
    y <- choose [10, 20]
    return (x, y)

선택지가 여러개 존재하는데, 각각의 선택지마다 나올 수 있는 답을 모두 모은 것이다. 즉, 위의 pairs[(1, 10), (1, 20), (2, 10), (2, 20), (3, 10), (3, 20)]이 된다. 이러한 식의 실행은 callCC로는 불가능하다. k를 처음 본 순간 바로 탈출하기 때문이다.

This post is licensed under CC BY 4.0 by the author.