{-# OPTIONS -fglasgow-exts #-}

module Tree where

import Control.Applicative (Applicative (..))
import Control.Monad (liftM)
import Data.List (partition, maximumBy, (\\))
import Data.Maybe (fromJust)
import Data.Ord (comparing)

import Data.Map (Map)
import qualified Data.Map as Map
import Text.PrettyPrint

import Finite
import Pretty

-- Numeric type defaulting for this module: an otherwise-ambiguous numeric
-- type (e.g. a top-level binding without a signature that only carries
-- Num/Fractional constraints) is resolved to exact 'Rational' arithmetic
-- instead of the standard (Integer, Double) defaults.
default (Rational)

-- | The three values a 'Process' can announce ('Say'), ask about ('Ask')
-- or pick among; 'host' uses them for the prize location and the
-- contestant's choices.
data Value
    = A
    | B
    | C
    deriving (Bounded, Enum, Eq, Ord, Show)

-- | 'Value' is a finite type.  Both methods delegate to the generic
-- @enum*@ helpers from "Finite" (presumably built on the derived
-- 'Enum'/'Bounded' instances — see that module).
instance Finite Value where
    cardinality = enumCardinality
    everything  = enumEverything

-- | No methods are overridden: pretty-printing a 'Value' relies entirely on
-- the default methods of the 'Pretty' class.
instance Pretty Value where

-- | A process tree.  @prob@ is the type of the weights attached to random
-- choices; @result@ is what the process produces when it finishes.
data Process prob result
    = End result
      -- ^ finished, producing @result@
    | Say Value (Process prob result)
      -- ^ announce a 'Value', then continue
    | Ask (Map Value (Process prob result))
      -- ^ receive a 'Value' and continue with the branch stored for it
    | Pick [(prob, Process prob result)]
      -- ^ random choice: each continuation is paired with its weight
    deriving (Show, Eq)

-- | A process with unit weights and unit result.  The 'Solution' instances
-- use it to give the polymorphic @policy@ field a concrete type.
type Policy = Process () ()

-- | 'fmap' over the result of a process, defined via the 'Monad' instance.
--
-- 'Functor' and 'Applicative' have been superclasses of 'Monad' since
-- GHC 7.10 (base 4.8); without these instances the 'Monad' instance below
-- does not compile on modern GHC.
instance Functor (Process prob) where
    fmap = liftM

-- | 'pure' finishes immediately; '<*>' runs the function process, then the
-- argument process, and applies one result to the other.
instance Applicative (Process prob) where
    pure      = End
    mf <*> mx = mf >>= \ f -> mx >>= \ x -> End (f x)

-- | @m >>= k@ replaces every @End a@ leaf of @m@ by @k a@.  'Say', 'Ask' and
-- 'Pick' nodes (including the 'Pick' weights) are kept unchanged, so only
-- the leaves of the tree are rewritten.
--
-- No constraint on @prob@ is needed: the weights are never inspected here.
instance Monad (Process prob) where
    return        = pure
    End a   >>= k = k a
    Say v m >>= k = Say v (m >>= k)
    Ask m   >>= k = Ask (Map.map (>>= k) m)
    Pick l  >>= k = Pick [ (p, m >>= k) | (p, m) <- l ]

-- | Render a process as nested text: a finished process shows its result,
-- 'Say' shows the announced value followed by the rest, and 'Ask' / 'Pick'
-- list one @label: continuation@ row per branch (the label being the value
-- asked for, or the weight picked with).
instance (Pretty prob, Pretty a) => Pretty (Process prob a) where
    pretty process = case process of
        End result        -> pretty result
        Say value rest    -> sep [text "Say" <+> pretty value, pretty rest]
        Ask branches      -> text "Ask" <+> vcat (map row (Map.toList branches))
        Pick alternatives -> text "Pick" <+> vcat (map row alternatives)
      where
        -- One branch line: the label, a colon, then the continuation.
        row :: (Pretty x, Pretty y) => (x, y) -> Doc
        row (label, sub) = (pretty label <> colon) <+> pretty sub

-- | Announce a single value, then finish with @()@.
say :: Value -> Process m ()
say = (`Say` End ())

-- | Branch on each of the given values; the branch for a value simply
-- returns that value.
ask :: [Value] -> Process m Value
ask vs = Ask (Map.fromList (zip vs (map End vs)))

-- | Pick one of the given values at random, every value carrying the same
-- weight @recip (length vs)@.  An empty list yields @Pick []@; the weight is
-- never evaluated in that case.
pickUniform :: (Fractional prob) => [Value] -> Process prob Value
pickUniform vs = Pick (zip (repeat share) (map End vs))
  where
    share = recip (fromIntegral (length vs))

-- | The host's side of what appears to be the Monty Hall game, with exact
-- probabilities and a 0/1 utility:
--
-- 1. the prize value is picked uniformly at random (it is not announced
--    until the last step);
-- 2. the contestant is 'ask'ed for an initial choice;
-- 3. the host picks uniformly among the values that are neither the
--    initial choice nor the prize, and announces it;
-- 4. the contestant is asked for a final choice;
-- 5. the prize is announced; the utility is 1 if the final choice was the
--    prize and 0 otherwise.
--
-- The signature spells out the type that the monomorphism restriction and
-- the module's @default (Rational)@ declaration already select.
host :: Process Rational Rational
host = do
    prize   <- pickUniform everything
    initial <- ask everything
    open    <- pickUniform (everything \\ [initial, prize])
    say open
    final   <- ask everything
    say prize
    return (if final == prize then 1 else 0)

-- | A fixed contestant strategy, written from the contestant's side:
--
-- * pick an initial value uniformly at random and announce it;
-- * learn which value was opened;
-- * announce a value that is neither the initial pick nor the opened one
--   (i.e. switch);
-- * learn the prize, which this strategy does not use.
--
-- The signature spells out the type that the monomorphism restriction and
-- the module's @default (Rational)@ declaration already select.
contestant2 :: Process Rational ()
contestant2 = do
    initial <- pickUniform everything
    say initial
    open <- ask everything
    say (switchFrom initial open)
    _ <- ask everything   -- the revealed prize; ignored by this strategy
    return ()
  where
    -- The first value that is neither @mine@ nor @opened@.  With three
    -- values and at most two removed the list cannot be empty; the error
    -- only guards against future changes to 'Value'.
    switchFrom :: Value -> Value -> Value
    switchFrom mine opened = case everything \\ [mine, opened] of
        (other : _) -> other
        []          -> error "contestant2: no value left to switch to"

-- | The result of 'solve': a policy together with its utility.  The policy
-- is polymorphic in the weight type, so it can be used as a process at any
-- @prob@.
data Solution a = Solution
    { policy  :: forall prob. Process prob ()  -- ^ the chosen strategy
    , utility :: a                             -- ^ the utility it achieves
    }

-- | Solutions are equal when their utilities are equal and their policies,
-- viewed at the concrete type 'Policy', are structurally equal.  Utilities
-- are compared first.
instance (Eq a) => Eq (Solution a) where
    Solution policyL utilityL == Solution policyR utilityR
        | utilityL == utilityR = (policyL :: Policy) == policyR
        | otherwise            = False

-- | Show a 'Solution' in record syntax, as a derived 'Show' instance would:
-- @Solution {policy = ..., utility = ...}@, parenthesised when it appears
-- as an argument (precedence 11 or higher).
--
-- The previous version applied 'show' to the opening string literal, which
-- wrapped it in escaped quotes, and labelled the field "strategy" although
-- the record field is 'policy'.  The rank-2 policy field is shown at the
-- concrete type 'Policy'.
instance (Show a) => Show (Solution a) where
    showsPrec d (Solution m u) = showParen (d >= 11) $
          showString "Solution {policy = "
        . showsPrec 0 (m :: Policy)
        . showString ", utility = "
        . showsPrec 0 u
        . showChar '}'

-- | Two lines: the policy (at the concrete type 'Policy') above the utility,
-- joined with '$+$' so the lines never overlap.
instance (Pretty a) => Pretty (Solution a) where
    pretty (Solution chosen value) = policyLine $+$ utilityLine
      where
        policyLine  = text "policy =" <+> pretty (chosen :: Policy)
        utilityLine = text "utility =" <+> pretty value

-- | Find a policy maximising the utility of a process whose results are
-- utilities: flatten the top-level random choices (starting from total
-- weight 1) and hand the weighted situations to 'solve''.
solve :: (Ord a, Num a) => Process a a -> Solution a
solve process = solve' (unpick 1 process)

-- | Solve a weighted list of situations the agent cannot tell apart.
--
-- Each entry pairs a probability weight with the process running in that
-- situation.  Every entry is expected to start with the same constructor
-- ('Pick' nodes having been flattened away by 'unpick'); a mixture is
-- reported with an explicit error.
--
-- * @[]@ (e.g. from @Pick []@) carries no probability mass: the policy is
--   @End ()@ and the utility is 0.
-- * 'End': the utility is the weighted sum of the results, @sum [p * u]@.
-- * 'Say': the agent observes the announced value, so entries are grouped
--   by that value and each group is solved separately; the policy 'Ask's
--   for the value and the utility is the sum over the groups.
-- * 'Ask': every value offered by the first entry is tried in all entries
--   at once, and the option with the largest utility is kept (ties are
--   resolved by 'maximumBy').  An 'Ask' with no options still makes
--   'maximumBy' fail.
solve' :: (Ord a, Num a) => [(a, Process a a)] -> Solution a
solve' [] = Solution (End ()) 0
solve' l@((_, End _) : _)
    = Solution (End ()) (sum (map weighted l))
  where
    weighted (p, End u) = p * u
    weighted _          = solveMismatch "End"
solve' l@((_, Say _ _) : _)
    -- 'policy' is eta-expanded so its rank-2 type is instantiated at the
    -- use site; Map.fold (removed in containers 0.6) is replaced by a foldr
    -- over the elements in the same (ascending key) order.
    = Solution (Ask (Map.map (\ s -> policy s) groups))
               (foldr (+) 0 [ utility s | s <- Map.elems groups ])
  where
    groups = Map.map solve' (Map.fromListWith (++) (map observe l))
    observe (p, Say v k) = (v, unpick p k)
    observe _            = solveMismatch "Say"
solve' l@((_, Ask options) : _)
    = maximumBy (comparing utility)
        [ Solution (Say v k) u
        | v <- Map.keys options
        , let Solution k u = solve' (concatMap (choose v) l) ]
  where
    choose v (p, Ask branches) = case Map.lookup v branches of
        Just k  -> unpick p k
        Nothing -> error "solve': an Ask does not offer the same options in every situation"
    choose _ _ = solveMismatch "Ask"

-- | Report that 'solve'' was given situations whose processes do not all
-- start with the expected constructor.
solveMismatch :: String -> b
solveMismatch expected =
    error ("solve': expected every situation to be at " ++ expected
           ++ ", but their processes differ")

-- | Flatten nested 'Pick' nodes into a list of (weight, process) pairs,
-- multiplying each branch weight into the weight @prob@ accumulated so far.
-- A process whose head is not 'Pick' becomes a single entry.  Branch order
-- is preserved.
unpick :: (Num a) => a -> Process a b -> [(a, Process a b)]
unpick prob process = case process of
    Pick chances -> [ leaf | (p, k) <- chances, leaf <- unpick (prob * p) k ]
    _            -> [(prob, process)]
