module Parse1 where

import List (transpose)
import qualified Data.Map as M

-- | Grammatical categories used by both 'lexicon' and 'syntax'.
--
-- 'Ord' is derived because symbols serve as the keys of a 'Cell' map.
data Symbol
    = Sentence            -- ^ Appears only as the result of the NounPhrase + VerbPhrase rule.
    | NounPhrase
    | VerbPhrase
    | Determiner          -- ^ Words such as \"a\" and \"an\" in 'lexicon'.
    | Noun
    | TransitiveVerb      -- ^ Combines with a following NounPhrase into a VerbPhrase.
    | Preposition         -- ^ Combines with a following NounPhrase into a PrepositionPhrase.
    | PrepositionPhrase
    deriving (Eq, Ord, Show)

-- | One token of the input sentence; compared for exact equality against the
-- words listed in 'lexicon'.
--
-- NOTE(review): newer versions of base also export a type called @Word@ from
-- the Prelude, which would make unqualified uses of this alias ambiguous --
-- confirm against the compiler this file targets.
type Word = String

-- | Word-level part of the grammar: each pair states that the word may be
-- read as the given symbol. A word may appear under several symbols.
--
-- The order of the entries is significant only in that it determines the
-- order in which alternative parses are produced.
lexicon :: [(Symbol, Word)]
lexicon =
  [ (NounPhrase, "time")
  , (Noun, "time")
  , (TransitiveVerb, "time")
  , (NounPhrase, "fruit")
  , (Noun, "fruit")
  , (NounPhrase, "flies")
  , (Noun, "flies")
  , (VerbPhrase, "flies")
  , (TransitiveVerb, "like")
  , (Preposition, "like")
  , (Determiner, "a")
  , (Determiner, "an")
  , (Noun, "arrow")
  , (Noun, "banana")
  ]

-- | Binary production rules. A triple @(parent, left, right)@ states that a
-- @left@ constituent immediately followed by a @right@ constituent forms a
-- @parent@ constituent.
syntax :: [(Symbol, Symbol, Symbol)]
syntax =
  [ (Sentence, NounPhrase, VerbPhrase)
  , (NounPhrase, Determiner, Noun)
  , (NounPhrase, Noun, Noun)
  , (Noun, Noun, Noun)
  , (VerbPhrase, VerbPhrase, PrepositionPhrase)
  , (PrepositionPhrase, Preposition, NounPhrase)
  , (VerbPhrase, TransitiveVerb, NounPhrase)
  ]

-- | A parse tree labelled with the symbol it derives.
--
-- The second field is 'Nothing' for a leaf covering a single word, or
-- @Just (left, m, right)@ for an inner node, where @m@ is the number of
-- words covered by @left@ (it is the position at which 'isParse' splits the
-- word list between the two children).
data Parse = Parse Symbol (Maybe (Parse, Int, Parse))
    deriving (Eq, Ord, Show)

-- | One chart entry: for each symbol, the list of parse trees found with
-- that symbol at the root. Built by 'unary' and 'innerProd'.
type Cell = M.Map Symbol [Parse]

-- | Decide whether a tree is a valid derivation of exactly the given words.
--
-- A leaf is valid only for a one-word list whose (symbol, word) pair is in
-- 'lexicon'. An inner node is valid when its (parent, left, right) symbols
-- form a rule in 'syntax' and each child is valid for its share of the
-- words, the list being cut at the split position stored in the node.
isParse :: Parse -> [Word] -> Bool
isParse (Parse sym Nothing) ws =
  case ws of
    [w] -> (sym, w) `elem` lexicon
    _   -> False
isParse (Parse sym (Just (lhs@(Parse lsym _), cut, rhs@(Parse rsym _)))) ws =
  ruleExists && isParse lhs front && isParse rhs back
  where
    ruleExists    = (sym, lsym, rsym) `elem` syntax
    (front, back) = splitAt cut ws

-- | All parse trees that derive the given words from the given symbol,
-- found by trying every split point and every matching rule recursively.
--
-- An empty word list has no parses; a single word is looked up in
-- 'lexicon'; longer lists are cut into a non-empty front and a non-empty
-- back, which are parsed as the left and right symbols of each rule whose
-- parent is the requested symbol.
parses :: [Word] -> Symbol -> [Parse] -- backtracking
parses []  _ = []
parses [w] s = [ Parse s Nothing | entry <- lexicon, entry == (s, w) ]
parses ws s =
  [ Parse s (Just (left, cut, right))
  | (front, cut, back) <- zip3 (prefixes ws) [1 ..] (drop 1 (suffixes ws))
  , (parent, leftSym, rightSym) <- syntax
  , parent == s
  , left <- parses front leftSym
  , right <- parses back rightSym
  ]

-- | The chart cell for a single word: every symbol that 'lexicon' lists for
-- the word is mapped to a leaf parse with that symbol.
unary :: Word -> Cell
unary w = M.fromListWith (++) (map leafEntry categories)
  where
    -- Symbols under which the word is listed, in lexicon order.
    categories = [ sym | (sym, entry) <- lexicon, entry == w ]
    leafEntry sym = (sym, [Parse sym Nothing])

-- | Combine a list of left cells with a list of right cells position by
-- position (extra cells in the longer list are ignored) and merge the
-- results into one cell.
--
-- The pair at position @k@ (counting from 1) is joined with split position
-- @k@: for every rule in 'syntax' whose left and right symbols occur in the
-- respective cells, each left tree is paired with each right tree under the
-- rule's parent symbol.
innerProd :: [Cell] -> [Cell] -> Cell
innerProd cls crs =
  M.unionsWith (++) [ joinAt k l r | (l, k, r) <- zip3 cls [1 ..] crs ]
  where
    joinAt cut leftCell rightCell = M.fromListWith (++) $ do
      (leftSym, leftTrees)   <- M.toList leftCell
      (rightSym, rightTrees) <- M.toList rightCell
      (parent, a, b)         <- syntax
      if a == leftSym && b == rightSym
        then [ ( parent
               , [ Parse parent (Just (lt, cut, rt))
                 | lt <- leftTrees, rt <- rightTrees ] ) ]
        else []

-- | All non-empty prefixes of a list, shortest first.
--
-- The first prefix is available as soon as the first element of the input
-- is, so this also works on infinite or still-being-defined lists.
prefixes :: [a] -> [[a]]
prefixes = foldr extend []
  where
    -- Prefixes of (x : rest) are [x] followed by x consed onto each prefix of rest.
    extend x shorter = [x] : map (x :) shorter

-- | All non-empty suffixes of a list, longest first (the list itself comes
-- first; the empty suffix is not included).
suffixes :: [a] -> [[a]]
suffixes = takeWhile (not . null) . iterate (drop 1)

-- | Build the CYK chart bottom-up from the single-word cells.
--
-- The result is a list of rows. The first row is @cells@ itself; every
-- later row is computed by 'innerProd' from all the rows before it and is
-- one cell shorter than its predecessor, so row @k@ (counting from 0) holds
-- the cells for spans of @k + 1@ consecutive words. The list stops before
-- the first empty row; in particular, empty input yields an empty list.
cyk :: [Cell] -> [[Cell]]
cyk cells = pyramid
        -- The chart is defined in terms of its own prefixes: each new row
        -- only needs the rows already produced, which works because
        -- 'prefixes' and 'takeWhile' are lazy.
  where pyramid = takeWhile (not . null) (cells : map row (prefixes pyramid))
        -- Given the rows built so far, compute the next row. For each start
        -- position, the first list pairs up the shorter-to-longer left
        -- spans beginning there, and the second list (rows reversed, each
        -- shifted by the left span's length) supplies the matching right
        -- spans, so 'innerProd' sees every split of the new, longer span.
        row trapezoid = zipWith innerProd
                            (transpose trapezoid)
                            (transpose (zipWith drop [1..] (reverse trapezoid)))

-- | Chart-based counterpart of 'parses': all parse trees that derive the
-- given words from the given symbol, read off the top cell of the 'cyk'
-- chart built from the per-word 'unary' cells.
--
-- An empty word list produces an empty chart, which has no top cell; in
-- that case the result is @[]@ (the same answer 'parses' gives) rather than
-- a runtime failure from taking the last row of an empty chart.
parses' :: [Word] -> Symbol -> [Parse] -- CYK dynamic programming
parses' ws s =
  -- Reversing puts the final row first so it can be matched directly,
  -- avoiding the partial 'last' and an irrefutable single-cell pattern.
  case reverse (cyk (map unary ws)) of
    ([apex] : _) -> M.findWithDefault [] s apex
    _            -> []

