// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


#include "Distribution.h"
#include "log_add.h"

namespace Torch {

// Builds a distribution with default sizes: no inputs, no observations and a
// single output.
Distribution::Distribution() : GradientMachine()
{
  n_inputs = 0;
  n_observations = 0;

  // One output by default. Per the original author's note, it is meant to
  // hold the negative log-likelihood (nll) when that is what gets optimized.
  n_outputs = 1;
}

void Distribution::allocateMemory()
{
  GradientMachine::allocateMemory();
}

void Distribution::freeMemory()
{
  GradientMachine::freeMemory();
}

void Distribution::init()
{
  GradientMachine::init();
  //done in GradientMachine::init():   
  //allocateMemory();
  //do the reset yourself in the main, whenever you want
  //reset();
}

// Default reset: does nothing at this level.
// NOTE(review): presumably meant to be overridden by concrete distributions
// that have parameters to re-initialize -- confirm against the header.
void Distribution::reset()
{
  // intentionally empty
}

// Returns the number of parameters of this distribution. The base
// implementation always reports zero.
int Distribution::numberOfParams()
{
  const int n_params_here = 0;
  return n_params_here;
}

// Returns the log-probability of a whole sequence, computed as the sum of
// frameLogProbability() over every frame.
//
// inputs: list node whose `ptr` field is interpreted as a SeqExample*.
//
// For each frame index, setFrameExample() is called first and the frame's
// input/observation vectors are then looked up at ex->current_frame. When the
// example has no `inputs` (resp. `observations`) array, the corresponding
// pointer handed to frameLogProbability() remains NULL.
real Distribution::logProbability(List *inputs)
{
  SeqExample* seq = (SeqExample*)inputs->ptr;

  // Declared outside the loop on purpose: a pointer keeps its previous value
  // whenever the matching array is absent.
  real* frame_in = NULL;
  real* frame_obs = NULL;

  real total_ll = 0;
  int frame = 0;
  while (frame < seq->n_frames) {
    setFrameExample(seq, frame);
    if (seq->inputs)
      frame_in = seq->inputs[seq->current_frame];
    if (seq->observations)
      frame_obs = seq->observations[seq->current_frame];
    total_ll += frameLogProbability(frame_obs, frame_in, seq->current_frame);
    frame++;
  }
  return total_ll;
}

// Viterbi variant of the sequence log-probability. At this level it is the
// same value as logProbability(inputs).
real Distribution::viterbiLogProbability(List *inputs)
{
  real sequence_ll = logProbability(inputs);
  return sequence_ll;
}

// Log-probability of a single frame. The base implementation ignores all of
// its arguments and returns LOG_ZERO.
// NOTE(review): presumably concrete distributions override this -- confirm
// against the header.
//
// observations: observation vector of the frame (unused here).
// inputs:       input vector of the frame (unused here).
// t:            frame index (unused here).
real Distribution::frameLogProbability(real *observations, real *inputs, int t)
{
  // Explicitly mark the parameters as unused.
  (void)observations;
  (void)inputs;
  (void)t;

  return LOG_ZERO;
}

// Per-frame generation hook. The base implementation does nothing and leaves
// `observations` untouched.
//
// observations: observation vector of the frame (unused here).
// inputs:       input vector of the frame (unused here).
// t:            frame index (unused here).
void Distribution::frameGenerate(real *observations, real *inputs,int t)
{
  // Explicitly mark the parameters as unused.
  (void)observations;
  (void)inputs;
  (void)t;
}

// Per-frame expectation hook. The base implementation does nothing and leaves
// `observations` untouched.
//
// observations: observation vector of the frame (unused here).
// inputs:       input vector of the frame (unused here).
// t:            frame index (unused here).
void Distribution::frameExpectation(real *observations, real *inputs,int t)
{
  // Explicitly mark the parameters as unused.
  (void)observations;
  (void)inputs;
  (void)t;
}

void Distribution::iterInitialize()
{
  eMIterInitialize();
}

// EM per-iteration initialization hook. Does nothing at this level; it is
// invoked by iterInitialize() and at the end of loadFILE().
void Distribution::eMIterInitialize()
{
  // intentionally empty
}

// EM per-sequence initialization hook, called by eMForward() and
// viterbiForward() before the log-probability is computed. The base
// implementation does nothing.
//
// inputs: the sequence about to be processed (unused here).
void Distribution::eMSequenceInitialize(List* inputs)
{
  // Explicitly mark the parameter as unused.
  (void)inputs;
}

// Per-sequence initialization hook, called by forward() before the
// log-probability is computed. The base implementation does nothing.
//
// inputs: the sequence about to be processed (unused here).
void Distribution::sequenceInitialize(List* inputs)
{
  // Explicitly mark the parameter as unused.
  (void)inputs;
}

// Calls frameEMAccPosteriors() once per frame of the sequence, passing the
// same `log_posterior` to every frame.
//
// inputs:        list node whose `ptr` field is interpreted as a SeqExample*.
// log_posterior: value forwarded unchanged to each per-frame call.
//
// When the example has no `inputs` (resp. `observations`) array, the
// corresponding pointer handed to the per-frame call remains NULL.
void Distribution::eMAccPosteriors(List *inputs, real log_posterior)
{
  SeqExample* seq = (SeqExample*)inputs->ptr;

  // Declared outside the loop on purpose: a pointer keeps its previous value
  // whenever the matching array is absent.
  real* frame_in = NULL;
  real* frame_obs = NULL;

  int frame = 0;
  while (frame < seq->n_frames) {
    setFrameExample(seq, frame);
    if (seq->inputs)
      frame_in = seq->inputs[seq->current_frame];
    if (seq->observations)
      frame_obs = seq->observations[seq->current_frame];
    frameEMAccPosteriors(frame_obs, log_posterior, frame_in, seq->current_frame);
    frame++;
  }
}

// Calls frameViterbiAccPosteriors() once per frame of the sequence, passing
// the same `log_posterior` to every frame.
//
// inputs:        list node whose `ptr` field is interpreted as a SeqExample*.
// log_posterior: value forwarded unchanged to each per-frame call.
//
// When the example has no `inputs` (resp. `observations`) array, the
// corresponding pointer handed to the per-frame call remains NULL.
void Distribution::viterbiAccPosteriors(List *inputs, real log_posterior)
{
  SeqExample* seq = (SeqExample*)inputs->ptr;

  // Declared outside the loop on purpose: a pointer keeps its previous value
  // whenever the matching array is absent.
  real* frame_in = NULL;
  real* frame_obs = NULL;

  int frame = 0;
  while (frame < seq->n_frames) {
    setFrameExample(seq, frame);
    if (seq->inputs)
      frame_in = seq->inputs[seq->current_frame];
    if (seq->observations)
      frame_obs = seq->observations[seq->current_frame];
    frameViterbiAccPosteriors(frame_obs, log_posterior, frame_in, seq->current_frame);
    frame++;
  }
}

// Per-frame EM posterior accumulation hook, invoked by eMAccPosteriors() for
// every frame. The base implementation accumulates nothing.
//
// observations:  observation vector of the frame (unused here).
// log_posterior: value passed down by eMAccPosteriors() (unused here).
// inputs:        input vector of the frame (unused here).
// t:             frame index (unused here).
void Distribution::frameEMAccPosteriors(real *observations, real log_posterior, real *inputs, int t)
{
  // Explicitly mark the parameters as unused.
  (void)observations;
  (void)log_posterior;
  (void)inputs;
  (void)t;
}

// Per-frame Viterbi posterior accumulation hook, invoked by
// viterbiAccPosteriors() for every frame. The base implementation
// accumulates nothing.
//
// observations:  observation vector of the frame (unused here).
// log_posterior: value passed down by viterbiAccPosteriors() (unused here).
// inputs:        input vector of the frame (unused here).
// t:             frame index (unused here).
void Distribution::frameViterbiAccPosteriors(real *observations, real log_posterior, real *inputs, int t)
{
  // Explicitly mark the parameters as unused.
  (void)observations;
  (void)log_posterior;
  (void)inputs;
  (void)t;
}

// EM update hook. Does nothing at this level.
// NOTE(review): presumably overridden by concrete distributions to
// re-estimate their parameters -- confirm against the header.
void Distribution::eMUpdate()
{
  // intentionally empty
}


// Default decoding: writes 0 into the first `real` pointed to by
// outputs->ptr. The `inputs` argument is not consulted.
void Distribution::decode(List *inputs)
{
   real* out_value = (real*)outputs->ptr;
   *out_value = 0;
}

// Forward pass over one sequence:
//   1. runs sequenceInitialize(inputs),
//   2. stores logProbability(inputs) in `log_probability`,
//   3. writes the negated log-probability into the first `real` of
//      outputs->ptr.
void Distribution::forward(List *inputs)
{
   sequenceInitialize(inputs);

   log_probability = logProbability(inputs);

   real* out_value = (real*)outputs->ptr;
   *out_value = -log_probability;
}

// EM forward pass over one sequence:
//   1. runs eMSequenceInitialize(inputs),
//   2. stores logProbability(inputs) in `log_probability`,
//   3. writes the negated log-probability into the first `real` of
//      outputs->ptr.
void Distribution::eMForward(List *inputs)
{
   eMSequenceInitialize(inputs);

   log_probability = logProbability(inputs);

   real* out_value = (real*)outputs->ptr;
   *out_value = -log_probability;
}

// Viterbi forward pass over one sequence:
//   1. runs eMSequenceInitialize(inputs),
//   2. stores viterbiLogProbability(inputs) in `log_probability`,
//   3. writes the negated log-probability into the first `real` of
//      outputs->ptr.
void Distribution::viterbiForward(List *inputs)
{
   eMSequenceInitialize(inputs);

   log_probability = viterbiLogProbability(inputs);

   real* out_value = (real*)outputs->ptr;
   *out_value = -log_probability;
}

// Backward pass over one sequence: calls frameBackward() once per frame,
// handing the same `alpha` pointer to every frame.
//
// inputs: list node whose `ptr` field is interpreted as a SeqExample*.
// alpha:  forwarded unchanged to each frameBackward() call.
//
// When the example has no `inputs` (resp. `observations`) array, the
// corresponding pointer handed to frameBackward() remains NULL.
void Distribution::backward(List *inputs, real *alpha)
{
  SeqExample* seq = (SeqExample*)inputs->ptr;

  // Declared outside the loop on purpose: a pointer keeps its previous value
  // whenever the matching array is absent.
  real* frame_in = NULL;
  real* frame_obs = NULL;

  int frame = 0;
  while (frame < seq->n_frames) {
    setFrameExample(seq, frame);
    if (seq->inputs)
      frame_in = seq->inputs[seq->current_frame];
    if (seq->observations)
      frame_obs = seq->observations[seq->current_frame];
    frameBackward(frame_obs, alpha, frame_in, seq->current_frame);
    frame++;
  }
}

// Viterbi variant of the backward pass. At this level it simply forwards
// both arguments to backward().
void Distribution::viterbiBackward(List *inputs, real *alpha)
{
  this->backward(inputs, alpha);
}

// Per-frame backward hook, invoked by backward() for every frame. The base
// implementation does nothing.
//
// observations: observation vector of the frame (unused here).
// alpha:        pointer passed down by backward() (unused here).
// inputs:       input vector of the frame (unused here).
// t:            frame index (unused here).
void Distribution::frameBackward(real *observations, real *alpha, real *inputs, int t)
{
  // Explicitly mark the parameters as unused.
  (void)observations;
  (void)alpha;
  (void)inputs;
  (void)t;
}

void Distribution::saveFILE(FILE *file)
{
  List *liste = params;

  while(liste)
  {
    xfwrite(liste->ptr, sizeof(real), liste->n, file);
    liste = liste->next;
  }
}

void Distribution::loadFILE(FILE *file)
{
  List *liste = params;

  while(liste)
  {
    xfread(liste->ptr, sizeof(real), liste->n, file);
    liste = liste->next;
  }
  eMIterInitialize();
}

// Destructor. Performs no work of its own at this level.
Distribution::~Distribution()
{
  // intentionally empty
}

}

