/*
 * This program translates Distr and DistrTest from Haskell to C++.
 */

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <utility>

// Define sampling procedures and three ways to make them: unit, choose, bind.

// A sampling procedure for values of type A, with probabilities of type P.
// A sample owns one polymorphic procedure object. Copying or assigning a
// sample deep-copies that object via clone(), so copies never share state.
template <typename P, typename A> class sample {
    // Interface of a sampling procedure.
    class procedure {
    public:
        // Draw one value.
        virtual A conduct() const = 0;
        // Return a heap-allocated copy of this procedure, owned by the caller.
        virtual procedure *clone() const = 0;
        virtual ~procedure() {}
    };

    // unit: always yields the stored value.
    class unit: public procedure {
        A a;
    public:
        unit(const A &_a): a(_a) {}
        procedure *clone() const { return new unit(a); }
        A conduct() const { return a; }
    };

    // bind: draws a value from m, passes it to k, then draws from the sample
    // that k returns.
    template <typename B, typename K>
    class bind: public procedure {
        sample<P,B> m;
        K k;
    public:
        bind(const sample<P,B> &_m, const K &_k): m(_m), k(_k) {}
        procedure *clone() const { return new bind(m, k); }
        A conduct() const { return k(m.proc->conduct()).proc->conduct(); }
    };

    // choose: picks one value from a list of (probability, value) pairs.
    class choose: public procedure {
        typedef std::list<std::pair<P,A> > list_type;
        list_type l;
    public:
        choose(const list_type &_l): l(_l) {}
        template <typename I> choose(const I &begin, const I &end):
            l(begin, end) {}
        procedure *clone() const { return new choose(l); }
        A conduct() const {
            assert(!l.empty());
            // POSIX random() returns values in [0, 2^31 - 1] independently of
            // RAND_MAX (which describes rand()), so scale by 2^31 to obtain a
            // number in [0, 1).
            P prob(P(random()) / P(2147483648.0));
            typename list_type::const_iterator fallback = l.begin();
            for (typename list_type::const_iterator i = l.begin();
                 i != l.end(); ++i) {
                if ((prob -= i->first) < 0) return i->second;
                if (i->first > 0) fallback = i;
            }
            // Floating-point rounding can leave the weights summing to
            // slightly less than 1. Give the leftover mass to the last entry
            // with positive weight instead of running off the end of the list.
            return fallback->second;
        }
    };

public:
    // The owned procedure. It is never null and is released by ~sample().
    procedure *proc;

    // copy constructor: deep copy of the procedure.
    sample(const sample &_): proc(_.proc->clone()) {}

    // copy assignment: clone before deleting. This makes self-assignment safe
    // and leaves *this untouched if clone() throws.
    sample &operator=(const sample &_) {
        procedure *copy = _.proc->clone();
        delete proc;
        proc = copy;
        return *this;
    }

    ~sample() { delete proc; }

    // unit
    sample(const A &a): proc(new unit(a)) {}

    // bind
    template <typename B, typename K>
    sample(const sample<P,B> &m, const K &k): proc(new bind<B,K>(m,k)) {}

    // choose
    template <typename I>
    sample(I begin, const I &end): proc(new choose(begin, end)) {}
};

// Print a sampling procedure by conducting it 100 times.

template <typename P, typename A>
std::ostream &operator<<(std::ostream &os, const sample<P,A> &s) {
    // Conduct the procedure 100 times; outcomes are separated by single spaces
    // with no leading or trailing separator.
    const int draws = 100;
    for (int k = 0; k < draws; ++k) {
        if (k != 0) os << ' ';
        os << s.proc->conduct();
    }
    return os;
}

// Define probability tables and three ways to make them: unit, choose, bind.

// An exact probability table: each outcome of type A is mapped to its
// accumulated probability of type P.
template <typename P, typename A> class table {
public:
    typedef std::map<A,P> map;
    map unTable;

    // unit: outcome a with probability 1.
    table(const A &a) {
        unTable.insert(std::make_pair(a, P(1)));
    }

    // bind: for every outcome b of m with nonzero probability, weight each
    // entry of the table k(b) by m's probability for b and accumulate.
    template <typename B, typename K>
    table(const table<P,B> &m, const K &k) {
        for (typename std::map<B,P>::const_iterator i = m.unTable.begin();
             i != m.unTable.end(); ++i) {
            if (!i->second) continue;  // k is not called for zero-mass outcomes
            const table ka(k(i->first));
            for (typename map::const_iterator j = ka.unTable.begin();
                 j != ka.unTable.end(); ++j)
                add(j->first, i->second * j->second);
        }
    }

    // choose: accumulate a range of (probability, value) pairs.
    template <typename I>
    table(I begin, const I &end) {
        for (; begin != end; ++begin) {
            const std::pair<P,A> entry(*begin);
            add(entry.second, entry.first);
        }
    }

private:
    // Add probability p to outcome a. Zero contributions are ignored.
    // A single insert() both creates a missing entry and locates an existing
    // one, so this needs one map lookup instead of find() followed by insert().
    void add(const A &a, const P &p) {
        if (!p) return;
        std::pair<typename map::iterator, bool> r =
            unTable.insert(std::make_pair(a, p));
        if (!r.second) r.first->second += p;
    }
};

// Specify how to print probability tables.

template <typename P, typename A>
std::ostream &operator<<(std::ostream &os, const table<P,A> &t) {
    bool first = true;
    for (typename table<P,A>::map::const_iterator i = t.unTable.begin();
         i != t.unTable.end(); ++i) {
        if (first) first = false; else os << ", ";
        os << i->first << ' ' << i->second;
    }
    return os;
}

// Define the function "uniform" to make a uniform probability distribution.

// Adapts an iterator I over values into an input iterator over
// (probability, value) pairs that all carry the same probability p.
template <typename P, typename I>
class uniform_iterator {
    typedef std::pair<P, typename std::iterator_traits<I>::value_type> entry;
    P p;
    I i;
public:
    uniform_iterator(const P &_p, const I &_i): p(_p), i(_i) {}

    // Pre-increment: advance the underlying iterator.
    uniform_iterator &operator++() { ++i; return *this; }

    // Post-increment. Input iterators must support r++ and *r++.
    uniform_iterator operator++(int) {
        uniform_iterator old(*this);
        ++i;
        return old;
    }

    // Yields the (p, value) pair by value. Declared const so that const
    // iterators can be dereferenced.
    entry operator*() const {
        return std::make_pair(p, *i);
    }

    bool operator==(const uniform_iterator &_) const {
        return p == _.p && i == _.i;
    }
    bool operator!=(const uniform_iterator &_) const {
        return !(*this == _);
    }
};

namespace std {
    // Traits for uniform_iterator. It is an input iterator whose dereference
    // returns a (probability, value) pair by value, so reference and pointer
    // are both that pair type. The distance type is taken from the wrapped
    // iterator I.
    template <typename P, typename I>
    struct iterator_traits<uniform_iterator<P, I> > {
        typedef input_iterator_tag iterator_category;
        typedef typename iterator_traits<I>::difference_type difference_type;
        typedef pair<P, typename iterator_traits<I>::value_type> value_type;
        typedef value_type reference;
        typedef value_type pointer;
    };
}

template <template <typename, typename> class D,
          typename P, typename A, typename I>
D<P,A> uniform(const I &begin, const I &end) {
    const P p = 1/P(std::distance(begin, end));
    return D<P,A>(uniform_iterator<P,I>(p, begin),
                  uniform_iterator<P,I>(p, end));
}

// Define the type "flip" of coin-flip outcomes, and specify how to print them.

enum flip { Heads, Tails };

// Write a coin-flip outcome as the word "Heads" or "Tails".
std::ostream &operator<<(std::ostream &os, flip f) {
    const char *name;
    if (f == Heads)
        name = "Heads";
    else
        name = "Tails";
    return os << name;
}

// Define the function "coin" to toss a fair coin.

template <template <typename, typename> class D, typename P>
D<P,flip> coin() {
    // Uniform distribution over the two faces of a coin.
    const flip faces[] = {Heads, Tails};
    const flip *const first = faces;
    const flip *const past_end = faces + sizeof faces / sizeof *faces;
    return uniform<D, P, flip, const flip *>(first, past_end);
}

// coinsFast: build the heads-count distribution for the first n-1 tosses, then toss one more coin.

// Continuation used by coinsFast. Given the heads count from the earlier
// tosses, it returns (as a point distribution) the count after one more
// toss c.
template <template <typename, typename> class D, typename P>
class coinsFast_inner_loop {
    int heads_so_far;
public:
    typedef D<P,int> result_type;
    coinsFast_inner_loop(int _l): heads_so_far(_l) {}
    D<P,int> operator() (flip c) const {
        if (c == Heads) return heads_so_far + 1;
        return heads_so_far;
    }
};

template <template <typename, typename> class D, typename P>
class coinsFast_outer_loop {
public:
    typedef D<P,int> result_type;
    D<P,int> operator() (int l) const {
        return D<P,int>(coin<D,P>(), coinsFast_inner_loop<D,P>(l));
    }
};

// Distribution of the number of heads in n fair coin tosses. It is built by
// binding the distribution for n-1 tosses to one additional toss. For n <= 0
// it returns the point mass at 0. Testing n == 0 only would let a negative n
// recurse without bound.
template <template <typename, typename> class D, typename P>
D<P,int> coinsFast(int n) {
    if (n <= 0) return 0;
    return D<P,int>(coinsFast<D,P>(n-1), coinsFast_outer_loop<D,P>());
}

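// coinsSlow: toss the first coin, then rebuild the distribution of the remaining n-1 tosses separately for each outcome.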

template <template <typename, typename> class D, typename P>
class coinsSlow_inner_loop {
    flip c;
public:
    coinsSlow_inner_loop(flip _c): c(_c) {}
    D<P,int> operator() (int l) const {
        return c == Heads ? l + 1 : l;
    }
};

// Forward declaration: coinsSlow_outer_loop and coinsSlow are mutually
// recursive.
template <template <typename, typename> class D, typename P>
D<P,int> coinsSlow(int n);

// Continuation used by coinsSlow. For each outcome c of the first toss, it
// rebuilds the distribution of the remaining n-1 tosses and adds c to it.
template <template <typename, typename> class D, typename P>
class coinsSlow_outer_loop {
    int remaining;
public:
    coinsSlow_outer_loop(int _n): remaining(_n) {}
    D<P,int> operator() (flip c) const {
        const D<P,int> rest = coinsSlow<D,P>(remaining - 1);
        return D<P,int>(rest, coinsSlow_inner_loop<D,P>(c));
    }
};

// Distribution of the number of heads in n fair coin tosses. It tosses the
// first coin and then, for each outcome, builds the distribution of the other
// n-1 tosses afresh. For n <= 0 it returns the point mass at 0. Testing
// n == 0 only would let a negative n recurse without bound.
template <template <typename, typename> class D, typename P>
D<P,int> coinsSlow(int n) {
    if (n <= 0) return 0;
    return D<P,int>(coin<D,P>(), coinsSlow_outer_loop<D,P>(n));
}

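// coinsMemo: like coinsSlow, but the distribution of the remaining n-1 tosses is built once and shared by both outcomes (reuses coinsSlow_inner_loop above).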

// Continuation used by coinsMemo. It holds one prebuilt distribution of the
// remaining tosses and, for each outcome c of the first toss, adds c to it.
template <template <typename, typename> class D, typename P>
class coinsMemo_outer_loop {
    D<P,int> remainder;
public:
    // Taking the distribution by const reference avoids copying it into the
    // parameter before copying it again into the member.
    coinsMemo_outer_loop(const D<P,int> &_remainder): remainder(_remainder) {}
    D<P,int> operator() (flip c) const {
        return D<P,int>(remainder, coinsSlow_inner_loop<D,P>(c));
    }
};

// Distribution of the number of heads in n fair coin tosses. Like coinsSlow,
// but the distribution for the other n-1 tosses is built once and shared by
// both outcomes of the first toss. For n <= 0 it returns the point mass at 0.
// Testing n == 0 only would let a negative n recurse without bound.
template <template <typename, typename> class D, typename P>
D<P,int> coinsMemo(int n) {
    if (n <= 0) return 0;
    return D<P,int>(coin<D,P>(),
                    coinsMemo_outer_loop<D,P>(coinsMemo<D,P>(n-1)));
}

// Main program.

int main() {
    srandom(time(0));

    const int n = 20;

    std::cout << "Testing sample:" << std::endl;
    std::cout << "coinsFast" << std::endl;
    std::cout << coinsFast<sample,double>(n) << std::endl;
    std::cout << "coinsSlow" << std::endl;
    std::cout << coinsSlow<sample,double>(n) << std::endl;
    std::cout << "coinsMemo" << std::endl;
    std::cout << coinsMemo<sample,double>(n) << std::endl;
    std::cout << std::endl;

    std::cout << "Testing table:" << std::endl;
    std::cout << "coinsFast" << std::endl;
    std::cout << coinsFast<table,double>(n) << std::endl;
    std::cout << "coinsSlow" << std::endl;
    std::cout << coinsSlow<table,double>(n) << std::endl;
    std::cout << "coinsMemo" << std::endl;
    std::cout << coinsMemo<table,double>(n) << std::endl;
    std::cout << std::endl;

    return 0;
}
