module Stretch where

import Distr
import DistrTest

-- | Running summary of a sequence of coin flips, advanced one flip at a time
-- by 'transition'. The four components are, in order:
--
-- 1. length of the longest stretch of identical consecutive flips so far
-- 2. length of the current stretch (the run ending at the last flip)
-- 3. the last flip
-- 4. total number of heads so far
type State = (Int, Int, Flip, Int)

-- | Advance the running summary by one coin flip.
--
-- The current run grows by one when the new flip equals the previous one and
-- restarts at 1 otherwise; the longest run is the maximum of the updated
-- current run and the previous longest; the head counter is incremented only
-- when the new flip is 'Heads'.
transition :: State -> Flip -> State
transition (best, run, prev, headCount) flipped =
    (max run' best, run', flipped, headCount')
  where
    -- Length of the run that ends with the flip just observed.
    run'
        | prev == flipped = run + 1
        | otherwise       = 1
    -- Heads seen so far, including the flip just observed.
    headCount'
        | flipped == Heads = headCount + 1
        | otherwise        = headCount

-- | Distribution over the 'State' reached after @n@ coin flips, restricted to
-- outcomes whose number of heads stays at or below @k@.
--
-- The distribution for @n@ flips is built from the one for @n - 1@ flips by
-- drawing one more flip from 'coin' and applying 'transition'. Whenever the
-- resulting head count exceeds @k@ the branch is replaced by @choose []@
-- (presumably the empty distribution, i.e. the branch is discarded -- confirm
-- against the "Distr" module).
--
-- For @n == 0@ the result is the single initial state, regardless of @k@.
-- A negative @n@ is rejected with an 'error' call; the recursion counts down
-- by one per step and would otherwise never reach the base case.
stretch :: (Ord p, Fractional p, Distr d, Random p) => Int -> Int -> d p State
stretch n k
    | n < 0 = error "stretch: number of flips must be non-negative"
    -- The 'Heads' in the initial state is only a placeholder: the current
    -- stretch starts at 0, so 'transition' yields a stretch of 1 after the
    -- first flip whether or not it matches the placeholder.
    | n == 0 = unit (0, 0, Heads, 0)
    | otherwise =
        bind (stretch (n - 1) k) (\s ->
        bind coin (\c ->
        let s'@(_, _, _, heads) = transition s c
        in if heads <= k
           then unit s'
           else choose []))
