// Copyright (C) 2002 Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "Trainer.h"

namespace Torch {

// Binds the trainer to a machine and to the dataset it trains on.
// Both pointers are stored as given; nothing is copied, and the destructor
// does not release either of them.
Trainer::Trainer(Machine *machine_, DataSet *data_)
{
  this->data = data_;
  this->machine = machine_;
}

// Evaluates the machine on every dataset observed by `measurers`.
//
// The measurers are grouped per dataset with extractMeasurers(). For each
// group the following happens in order:
//   1. every measurer is reset();
//   2. the machine is run forward on each example of the group's dataset,
//      and measureEx() is called on every measurer after each example;
//   3. measureIter() is called on all of them, and only afterwards
//      measureEnd().
void Trainer::test(List *measurers)
{
  message("Trainer: testing");

  DataSet **datas;
  Measurer ***mes;
  int *n_mes;
  int n_datas;
  extractMeasurers(measurers, NULL, &datas, &mes, &n_mes, &n_datas);

  for(int d = 0; d < n_datas; d++)
  {
    DataSet *dataset = datas[d];
    Measurer **group = mes[d];
    const int n_group = n_mes[d];

    for(int k = 0; k < n_group; k++)
      group[k]->reset();

    for(int t = 0; t < dataset->n_examples; t++)
    {
      dataset->setExample(t);
      machine->forward(dataset->inputs);
      for(int k = 0; k < n_group; k++)
        group[k]->measureEx();
    }

    // Close the pass: all measureIter() calls precede any measureEnd().
    for(int k = 0; k < n_group; k++)
      group[k]->measureIter();
    for(int k = 0; k < n_group; k++)
      group[k]->measureEnd();
  }

  deleteExtractedMeasurers(datas, mes, n_mes, n_datas);
}

void Trainer::crossValidate(int k_fold, List *train_measurers, List *test_measurers, List *cross_valid_measurers)
{
  int *mix_subset = (int *)xalloc(sizeof(int)*data->n_examples);
  getShuffledIndices(mix_subset, data->n_examples);
  data->pushSubset(mix_subset, data->n_examples);

  List *measurers_ = cross_valid_measurers;
  while(measurers_)
  {
    ((Measurer *)measurers_->ptr)->reset();
    measurers_ = measurers_->next;
  }

  int taille_subset = data->n_examples/k_fold;
  int *test_subset = (int *)xalloc(sizeof(int)*(taille_subset+data->n_examples%k_fold));
  int *train_subset = (int *)xalloc(sizeof(int)*(data->n_examples-taille_subset));

  for(int i = 0; i < k_fold; i++)
  {
    int n_train_subset = 0;
    int n_test_subset = 0;
    
    for(int j = 0; j < i*taille_subset; j++)
      train_subset[n_train_subset++] = j;
    for(int j = i*taille_subset; j < (i+1)*taille_subset; j++)
      test_subset[n_test_subset++] = j;
    if(i == k_fold-1)
    {
      for(int j = (i+1)*taille_subset; j < data->n_examples; j++)
        test_subset[n_test_subset++] = j;
    }
    else
    {
      for(int j = (i+1)*taille_subset; j < data->n_examples; j++)
        train_subset[n_train_subset++] = j;
    }

    data->pushSubset(train_subset, n_train_subset);
    machine->reset();
    train(train_measurers);
    data->popSubset();

    data->pushSubset(test_subset, n_test_subset);
    test(test_measurers);
    data->popSubset();

    measurers_ = cross_valid_measurers;
    while(measurers_)
    {
      ((Measurer *)measurers_->ptr)->measureIter();
      measurers_ = measurers_->next;
    }
  }

  measurers_ = cross_valid_measurers;
  while(measurers_)
  {
    ((Measurer *)measurers_->ptr)->measureEnd();
    measurers_ = measurers_->next;
  }
  
  data->popSubset();
  free(test_subset);
  free(train_subset);
  free(mix_subset);
}

// Use at your own risk: runs the machine forward on example `t` and
// re-measures it.
//
// Only the dataset of the FIRST measurer in the list is positioned on example
// `t` and fed to the machine. Every measurer in the list is then reset() and
// given a single measureEx() call, whatever dataset it observes.
// Does nothing when `measurers` is NULL.
void Trainer::testExample(List *measurers, int t)
{
  if(measurers == NULL)
    return;

  DataSet *first_data = ((Measurer *)measurers->ptr)->data;
  first_data->setExample(t);
  machine->forward(first_data->inputs);

  for(List *node = measurers; node != NULL; node = node->next)
  {
    Measurer *m = (Measurer *)node->ptr;
    m->reset();
    m->measureEx();
  }
}

// Frees nothing: the machine and dataset the trainer points to are left
// untouched.
Trainer::~Trainer() {}

// Groups the measurers of the `measurers` list by the dataset they observe.
//
// Outputs:
//   *datas   : the *n_datas distinct datasets found, in order of first
//              appearance. If `train` is non-NULL it is always entry 0, even
//              when no measurer observes it.
//   *mes     : (*mes)[d] holds the measurers attached to (*datas)[d], in list
//              order.
//   *n_mes   : (*n_mes)[d] is the number of those measurers.
// All arrays are allocated with xalloc and must be released with
// deleteExtractedMeasurers().
//
// Sizing (deliberately generous):
//   - With m measurers there can be at most m+1 distinct datasets
//     (m different ones plus `train`).
//   - Therefore every array, including each row of *mes, gets m+1 slots.
//   - deleteExtractedMeasurers() relies on exactly m+1 rows having been
//     allocated.
void extractMeasurers(List *measurers, DataSet *train, DataSet ***datas, Measurer ****mes, int **n_mes, int *n_datas)
{
  // One slot per measurer, plus a spare one for `train` (worst case: every
  // measurer has its own dataset and none of them is `train`).
  int n_slots = 1;
  for(List *node = measurers; node; node = node->next)
    n_slots++;

  DataSet **datas_ = (DataSet **)xalloc(sizeof(DataSet *)*n_slots);
  Measurer ***mes_ = (Measurer ***)xalloc(sizeof(Measurer **)*n_slots);
  int *n_mes_ = (int *)xalloc(sizeof(int)*n_slots);
  for(int d = 0; d < n_slots; d++)
  {
    mes_[d] = (Measurer **)xalloc(sizeof(Measurer *)*n_slots);
    n_mes_[d] = 0;
  }

  // `train`, when given, always occupies the first dataset slot.
  int n_found = 0;
  if(train)
    datas_[n_found++] = train;

  // Single pass: look up (or register) each measurer's dataset, then append
  // the measurer to that dataset's row.
  for(List *node = measurers; node; node = node->next)
  {
    Measurer *m = (Measurer *)node->ptr;
    DataSet *curr_dat = m->data;

    int d = 0;
    while(d < n_found && datas_[d] != curr_dat)
      d++;
    if(d == n_found)
      datas_[n_found++] = curr_dat;

    mes_[d][n_mes_[d]++] = m;
  }

  *datas = datas_;
  *mes = mes_;
  *n_mes = n_mes_;
  *n_datas = n_found;
}

// Releases the arrays produced by extractMeasurers().
//
// The number of rows of `mes` is not stored, so it is recomputed:
//   - every measurer appears in exactly one row, so the sum of n_mes is the
//     measurer count;
//   - extractMeasurers() allocated one spare row on top of that.
void deleteExtractedMeasurers(DataSet **datas, Measurer ***mes, int *n_mes, int n_datas)
{
  int n_rows = 1; // the spare slot added by extractMeasurers()
  for(int d = 0; d < n_datas; d++)
    n_rows += n_mes[d];

  for(int r = 0; r < n_rows; r++)
    free(mes[r]);
  free(mes);
  free(datas);
  free(n_mes);
}

void Trainer::loadFILE(FILE *file)
{
  data->loadFILE(file);
  machine->loadFILE(file);
}

void Trainer::saveFILE(FILE *file)
{
  data->saveFILE(file);
  machine->saveFILE(file);
}

}

