#    Copyright (C) 2004 Paul Harrison
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

""" MML (Minimum Message Length) mixture estimator.

    This module has functionality similar to Chris Wallace's "Snob" programs
    (but lacks some features found in Snob).
"""

from numarray import *
from numarray import random_array, records
import random

import estimate



class Composite_estimate(estimate.Estimate):
    """ Concatenate several sub-messages into one message.
        (Useful as a class estimate for Mixture_estimate)
        
        data: record.array with a field for each prior in "priors"
          
        priors: list of (Estimate class, prior parameters... )
            Specification of the Estimate class and prior parameters to use 
            for each field in the data. Field i of the data is modelled
            using priors[i].
                    
        weighting: optional per-datum weights, passed unchanged both to
            estimate.Estimate and to every sub-estimate.
        
        Fields:
        
            estimates : list of Estimate
                An estimate for each column of the data.
    
        Each message-length component (dimensions, neglog_prior,
        neglog_fisher, neglog_data, total_neglog_data) is the sum of the
        corresponding components of the sub-estimates, because the
        sub-messages are simply concatenated.
    """
    
    def __init__(self, data, priors, weighting=None):
        estimate.Estimate.__init__(self, data, weighting)
        
        # Fit one sub-estimate per field, using that field's (class, params...) spec.
        self.estimates = [ spec[0](data.field(i), weighting=weighting, *spec[1:])
                           for i, spec in enumerate(priors) ]
                           
        # The composite has a prior only if every part has one, and it
        # violates its prior if any part does.
        for sub in self.estimates:
            self.has_prior = self.has_prior and sub.has_prior
            self.violates_prior = self.violates_prior or sub.violates_prior
        
    def __repr__(self):
        result = estimate.Estimate.__repr__(self)
        
        # Append each sub-estimate's representation, indented by four spaces.
        for sub in self.estimates:
            result += ("\n"+repr(sub)).replace("\n","\n    ")
        
        return result

    # NOTE: the loop variable is "sub" rather than "estimate". In Python 2 a
    # list-comprehension variable leaks into the enclosing function, so
    # naming it "estimate" would shadow the imported estimate module.

    def dimensions(self):
        return sum([ sub.dimensions() for sub in self.estimates ])

    def neglog_prior(self):
        return sum([ sub.neglog_prior() for sub in self.estimates ])
        
    def neglog_fisher(self):
        return sum([ sub.neglog_fisher() for sub in self.estimates ])

    def total_neglog_data(self):
        return sum([ sub.total_neglog_data() for sub in self.estimates ])
        
    def neglog_data(self):
        return sum([ sub.neglog_data() for sub in self.estimates ])



class _Mixture_estimate(estimate.Estimate):
    """ A mixture estimate, consisting of a set of classes.
    
        You should not instantiate this class directly, instead use
        Mixture_estimate.
        
        data: data suitable for the class estimator given by "prior"
        
        prior: (Estimate class, prior parameters... )
            Class estimator and its prior parameters; one instance is
            fitted per class.
        
        suggested_assignments : N-data x N-class array of float
            Partial assignment of each datum to each class. Multiplied by
            the per-datum weighting, it becomes the weighting used to fit
            each class.
        
        weighting: optional per-datum weights (handled by estimate.Estimate;
            presumably uniform when None -- confirm in estimate.py).
        
        Fields:
        
            class_likelihoods : Discrete_estimate
                Estimate of the chance a random datum belongs to each class.
                
            classes : list of Estimate
                Class estimates.
            
            assignments : N-data x N-class array of float
                Partial-assignment matrix recomputed from the fitted
                classes: the likelihood of each datum belonging to each
                class, normalised so that each row sums to 1.
    """

    def __init__(self, data, prior, suggested_assignments, weighting=None):        
        estimate.Estimate.__init__(self, data, weighting)
        
        n_classes = suggested_assignments.shape[1]
        
        # Weight of each datum within each class.
        weighted_assignments = suggested_assignments * self.weighting[:,NewAxis]
        
        # Total weight per class (summed over the data axis) drives the
        # estimate of the class proportions.
        total = sum(weighted_assignments)
        self.class_likelihoods = estimate.Discrete_estimate(
            arange(n_classes), n_classes, weighting=total)
            
        # Fit every class to all of the data, weighted by that class's column.
        self.classes = [ prior[0](data, weighting=weighted_assignments[:,i], *prior[1:])
                         for i in xrange(n_classes) ]
        
        # neglog_likelihood_per_class[d, c] =
        #     -log P(class c) + (class c's neglog_data for datum d)
        neglog_likelihood_per_class = repeat([-log(self.class_likelihoods.probability)],
                                             len(data))
        
        for i in xrange(n_classes):
            neglog_likelihood_per_class[:,i] += self.classes[i].neglog_data()
        
        # Mixture neglog likelihood per datum would naively be
        #     -log(sum(exp(-neglog_likelihood_per_class), 1))
        # but exp() can underflow, so factor out the smallest term per row
        # (the log-sum-exp trick).
        min_neglog_likelihood = minimum.reduce(neglog_likelihood_per_class, 1)
        excess = exp( min_neglog_likelihood[:,NewAxis] - neglog_likelihood_per_class )
        self._neglog_data = min_neglog_likelihood - log(sum(excess,1))
        
        # Normalising the same per-row terms gives each datum's posterior
        # probability of belonging to each class.
        excess /= sum(excess,1)[:,NewAxis]
        self.assignments = excess
        
        # The mixture has a prior only if every class has one, and it
        # violates its prior if any class does.
        for item in self.classes:
            self.has_prior = self.has_prior and item.has_prior
            self.violates_prior = self.violates_prior or item.violates_prior
        
    def __repr__(self):
        # Base representation minus its first character (presumably a
        # leading newline -- NOTE(review): confirm against estimate.Estimate),
        # followed by one entry per class prefixed with its percentage share.
        result = estimate.Estimate.__repr__(self)[1:]
        
        for i, class_ in enumerate(self.classes):
            result += "\n% 3.0f%% " % (self.class_likelihoods.probability[i]*100.0)
            result += repr(class_).replace("\n","\n     ")
        
        return result
    
    def dimensions(self):
        # Free parameters: class proportions plus every class's parameters.
        result = self.class_likelihoods.dimensions()
        
        for class_ in self.classes:
            result += class_.dimensions()
        
        return result
        
    def neglog_prior(self):
        # Prior on the number of classes: P(n classes) = 2^-n
        result = len(self.classes) * log(2.0)
        
        result += self.class_likelihoods.neglog_prior()
        
        for class_ in self.classes:
            result += class_.neglog_prior()
        
        # All n! orderings of the classes describe the same mixture, so
        # log(n!) nats can be reclaimed.
        result -= sum(log(arange(1,len(self.classes)+1)))
            
        return result
    
    def neglog_fisher(self):
        result = self.class_likelihoods.neglog_fisher()
        
        for class_ in self.classes:
            result += class_.neglog_fisher()
            
        return result
    
    def neglog_data(self):
        # Per-datum mixture neglog likelihood computed in __init__.
        return self._neglog_data



def _destroy_class(assignments, index):
    """ Remove a column from an assignment matrix.
    
        The remaining columns are renormalised so that every row sums to 1.
        A row left with no weight at all (all of its weight was in the
        removed column) is spread uniformly over the remaining columns.
        
        Returns a new array; the input matrix is not modified. If the
        matrix has only one column the result has no columns, so callers
        should avoid that case.
    """
    
    remaining = concatenate((assignments[:,:index], assignments[:,index+1:]), 1)
    
    row_total = sum(remaining, 1)
    
    # Avoid divide-by-zero: rows with no weight get 1.0 in every column and
    # are divided by the column count instead, i.e. a uniform assignment.
    # (Named empty_rows rather than "zeros" so numarray's zeros() is not shadowed.)
    empty_rows = (row_total <= 0.0)
    if sometrue(empty_rows):
        # where() broadcasts the per-row mask along the last axis, hence
        # the pair of transposes.
        remaining = transpose(where(empty_rows, 1.0, transpose(remaining)))
        row_total = where(empty_rows, remaining.shape[1], row_total)
        
    remaining /= row_total[:,NewAxis]
    return remaining


def _split_class(assignments, index):
    """ Split column "index" of an assignment matrix in two, at random.
    
        The column is duplicated (the copy is inserted at index+1). A single
        random threshold is drawn, then each datum (row) keeps its weight in
        exactly one of the two copies, chosen at random against that
        threshold. Returns a new, wider array; the input is not modified.
    """
    
    widened = concatenate((assignments[:,:index+1], assignments[:,index:]), 1)
    
    threshold = random.random()
    row = 0
    while row < widened.shape[0]:
        if random.random() >= threshold:
            # Datum stays in the original column.
            widened[row, index+1] = 0.0
        else:
            # Datum moves to the new copy.
            widened[row, index] = 0.0
        row += 1
    
    return widened


def _prune(assignments, weighting, min_size=2.0):
    """ Remove any columns that are too small from an assignment matrix.
    
        A column (class) is too small when its weighted size,
        sum(assignments[:,i] * weighting), is below min_size. Columns are
        removed one at a time with _destroy_class, which renormalises the
        remaining columns after each removal.
        
        The last remaining column is never removed. Previously, data with a
        total weight below min_size would have every class pruned, leaving
        an N x 0 matrix that no mixture estimate can be built from.
    """
    
    while assignments.shape[1] > 1:
        for i in xrange(assignments.shape[1]):
            if sum(assignments[:,i] * weighting) < min_size:
                assignments = _destroy_class(assignments, i)
                # Column indices have shifted; rescan from the start.
                break
        else:
            # No column is too small.
            break
    
    return assignments
    
            
def Mixture_estimate(data, prior, iterations=1000, assignments=None, weighting=None):
    """ Construct a mixture estimate by iterative refinement.
    
        data : data suitable for class estimator given by "prior"
        
        prior: (Estimate class, prior parameters... )
            Specification of the Estimate class and prior parameters to use 
            when modelling the data.
            
        iterations : integer
            Number of refinement steps. Each step fits a _Mixture_estimate
            to the current assignments.
        
        assignments : N-data x N-classes array of float
            Optional initial guess as to the assignments of data to classes
            (default: a single class containing every datum). You can pass a
            previous _Mixture_estimate.assignments here to perform further
            iterations on that model.
        
        weighting : N-data array of float
            Optional per-datum weights (default: 1.0 for every datum).
    
        Returns the _Mixture_estimate with the shortest message length seen,
        or None if no model was built (iterations < 1).
        
        Search strategy: before each fit, classes with a weighted size below
        2 are pruned. After each fit, the columns of the new assignments are
        sorted by decreasing class probability and fed back in. Once the
        message length changes by less than 0.001 between steps, the search
        is perturbed at random:
          - with probability 0.1 it restarts from the best assignments so far;
          - otherwise, if there is more than one class, it deletes a random
            class half the time;
          - in all remaining cases it splits a random class in two.
        
        TODO: The algorithm used here is rather simpler than that of Snob, and has not been
              thoroughly tested.
    """

    if weighting is None:
        weighting = ones(len(data), Float)
    if assignments is None:
        assignments = ones((len(data),1), Float)

    best_length, best_model, best_assignments = 1e100, None, None
    
    # Length from the previous step; None right after a perturbation.
    previous_length = None
    
    for _ in xrange(iterations):
        assignments = _prune(assignments, weighting)
        model = _Mixture_estimate(data, prior, assignments, weighting)
        
        # Put the most probable class in the first column.
        descending = argsort(model.class_likelihoods.probability)[::-1]
        assignments = take(model.assignments, descending, 1)
        
        length = model.length()
        if length < best_length:
            best_length, best_model, best_assignments = length, model, assignments
        
        converged = (previous_length is not None
                     and abs(length - previous_length) < 0.001)
        if not converged:
            previous_length = length
            continue
        
        # Converged: shake the search out of its current state.
        if random.random() < 0.1:
            assignments = best_assignments
            previous_length = best_length
        else:
            previous_length = None
            n_columns = assignments.shape[1]
            if n_columns > 1 and random.random() < 0.5:
                assignments = _destroy_class(assignments, random.randrange(n_columns))
            else:
                assignments = _split_class(assignments, random.randrange(n_columns))

    return best_model



if __name__ == "__main__":
    # Run tests if not "import"ed.

    data = records.array([
        concatenate([random_array.normal(0.0, 10.0, 10),random_array.normal(20.0,2.0,90)]),
        random_array.normal(5.0, 10.0, 100)
    ])
    
    priors = [
        (estimate.Gaussian_estimate, (-30.0, 30.0), 1.0, 30.0),
        (estimate.Gaussian_estimate, (-30.0, 30.0), 1.0, 30.0)
    ]
    
    if 1:
        class_est = Composite_estimate(data, priors)
        print class_est
        print
    
    if 1:
        mixture_est = Mixture_estimate(data, (Composite_estimate, priors), 100)
        print mixture_est
        print
    
    if 1:
        weighting = zeros(len(data), Float)
        weighting[:10] = 1.0
        mixture_est = Mixture_estimate(data, (Composite_estimate, priors), 100, weighting=weighting)
        print mixture_est
        print
    
    data = transpose(array([data.field(0), data.field(1)]))
    
    if 1:
        prior = (estimate.Hidden_factor_estimate, 
                 array([[-30.0,30.0],[-30.0,30.0]]), array([1.0,1.0]), array([30.0,30.0]))
        mixture_est = Mixture_estimate(data, prior, 10)
        print mixture_est
        print
    
    if 1:
        prior = (estimate.Multivariate_gaussian_estimate, 1.0) 
        mixture_est = Mixture_estimate(data, prior, 100)
        print mixture_est
        print