module Stretch where

import Distr
import DistrTest

-- | Running summary of a sequence of coin flips, stored as a triple:
--
--   1. length of the longest stretch of identical flips seen so far,
--   2. length of the stretch ending at the most recent flip,
--   3. the most recent flip.
type State =
  ( Int   -- longest stretch length so far
  , Int   -- current stretch length
  , Flip  -- last flip
  )

-- | Advance the summary by one flip.
--
-- If the new flip repeats the previous one, the current stretch grows by
-- one. Otherwise a new stretch of length 1 starts. The longest length is
-- raised to the new current length when that is larger, and the new flip
-- is remembered as the previous flip.
transition :: State -> Flip -> State
transition (best, run, prev) nextFlip = (best', run', nextFlip)
  where
    run'
      | prev == nextFlip = run + 1
      | otherwise        = 1
    best' = max run' best

-- | The 'State' reached after @n@ draws of 'coin', built with the 'Distr'
-- operations 'unit' and 'bind'. The first component of the result is the
-- longest stretch of identical flips among those @n@ draws.
--
-- The starting state has both lengths 0. Its 'Heads' is only a placeholder:
-- with a current length of 0, 'transition' yields a stretch of 1 for the
-- first flip whichever side it lands on.
--
-- A non-positive @n@ is treated as zero flips. Matching only the literal 0
-- would make a negative count recurse forever.
stretch :: (Ord p, Fractional p, Distr d, Random p) => Int -> d p State
stretch n
  | n <= 0    = unit (0, 0, Heads)
  | otherwise =
      bind (stretch (n - 1)) $ \s ->
      bind coin $ \c ->
      unit (transition s c)
