/* pmmle_wave_mc.c
 *
 * Copyright (C) 2007 Stephane Germain
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */

/**
   \file 
   \brief Functions to estimate the multiple choice item
   parameters by PMMLE (Penalized Maximum Marginal Likelihood).

   The functional estimations are done by a wavelet decomposition, and
   then by using a root finding algorithm on the wavelet coefficients.

   \author Stephane Germain <germste@gmail.com>
*/

#include "libirt.h"

#include <stdio.h>
#include <math.h>
#include <gsl/gsl_errno.h>
#include <gsl/gsl_multiroots.h>
#include <gsl/gsl_linalg.h>
#include <gsl/gsl_wavelet.h>

/**
   \brief Turn a set of logits into option characteristic curves.

   Each logit is read as log(P_ref / P_option), where the reference
   option is the last row of \em probs. For every class the rows
   written therefore sum to one.

   @param[in] logits A matrix(logits x classes).
   @param[out] probs A matrix(options x classes). Rows 0..logits-1 and
   the last row are written; the callers in this file use
   options = logits + 1.
*/
void
probs_from_logits (gsl_matrix * logits, gsl_matrix * probs)
{
  const int n_class = probs->size2;
  const int n_logit = logits->size1;
  const int ref_row = (int) probs->size1 - 1;
  int cls, opt;

  for (cls = 0; cls < n_class; cls++)
    {
      /* normalizing constant: 1 + sum_i exp(-logit_i) */
      double denom = 1;

      for (opt = 0; opt < n_logit; opt++)
        denom += exp (-gsl_matrix_get (logits, opt, cls));

      /* every non reference option: exp(-logit) / denom */
      for (opt = 0; opt < n_logit; opt++)
        {
          double logit = gsl_matrix_get (logits, opt, cls);

          gsl_matrix_set (probs, opt, cls, 1 / (exp (logit) * denom));
        }

      /* the reference (last) option takes what is left */
      gsl_matrix_set (probs, ref_row, cls, 1 / denom);
    }
}

/**
   \brief Turn option characteristic curves into logits.

   The logit of option i in class k is log(P_ref / P_i), where the
   reference option is the last row of \em probs. The probabilities
   are used as given: no check is made that they are strictly
   positive.

   @param[in] probs A matrix(options x classes).
   @param[out] logits A matrix(logits x classes).
*/
void
logits_from_probs (gsl_matrix * probs, gsl_matrix * logits)
{
  const int ref_row = (int) probs->size1 - 1;
  const int n_class = probs->size2;
  const int n_logit = logits->size1;
  int opt, cls;

  for (opt = 0; opt < n_logit; opt++)
    for (cls = 0; cls < n_class; cls++)
      {
        double p_ref = gsl_matrix_get (probs, ref_row, cls);
        double p_opt = gsl_matrix_get (probs, opt, cls);

        gsl_matrix_set (logits, opt, cls, log (p_ref / p_opt));
      }
}

/**
   \brief Used to pass the extra parameters to \em mple_wave_mc_fdfdf2.

   The root finding functions of the gsl (GNU scientific library) only
   hand a single \em void pointer to their callbacks; this structure
   bundles the data of one item and the work space needed to evaluate
   the penalized log likelihood and its derivatives.
*/
typedef struct
{
  /** \brief The prior weights of each quadrature class.
      Stored by \em em_mple_wave_mc but currently not read by
      \em mple_wave_mc_fdfdf2, which uses a weight of 1. */
  gsl_vector *quad_weights;

  /** \brief The expected number of subject in each quadrature classes. */
  gsl_vector *quad_sizes;

  /** \brief The expected number of subject in each quadrature classes 
      and in each options of the item, a matrix(options x classes). */
  gsl_matrix *quad_freqs;

  /** \brief The penalizing factor applied to the squared second
      derivative of the logits. */
  double smooth_factor;

  /* I declare those here to save on the malloc and on some computation */

  /** \brief The wavelet family used for the forward and inverse
      wavelet transformations. */
  gsl_wavelet *wave;

  /** \brief The work space of the wavelet transformations. */
  gsl_wavelet_workspace *work;

  /** \brief The wavelets evaluated at each middle point, a
      matrix(wavelets x classes): row s is the inverse transform of
      the s-th unit vector. */
  gsl_matrix *wavelets;

  /** \brief The (2)th derivative of the wavelets evaluated at each
      middle point, a matrix(wavelets x classes). */
  gsl_matrix *deriv_wavelets;

  /** \brief A place to store the inverse wavelet transform (ie the logit),
      a matrix(logits x classes) allocated and freed by \em mple_wave_mc. */
  gsl_matrix *logits;

  /** \brief A place to store the (2)th derivative of the logit,
      a matrix(logits x classes) allocated and freed by \em mple_wave_mc. */
  gsl_matrix *deriv_logits;

  /** \brief A place to store the option response function. It is set
      by \em mple_wave_mc to the matrix of the item being estimated,
      which is thus overwritten at every evaluation. */
  gsl_matrix *probs;

} mple_wave_mc_struct;

/**
   \brief Compute the penalized log likelihood of one item, its
   gradient and its Hessian with respect to the wavelet coefficients.

   The coefficients of logit i occupy the entries i*classes to
   (i+1)*classes-1 of \em par_wave. The logits are recovered by an
   inverse wavelet transform, clamped to [-VERY_BIG_LOGIT,
   VERY_BIG_LOGIT] and converted to option probabilities by
   \em probs_from_logits. The penalty is \em smooth_factor times the
   squared second derivative of each logit, summed over the classes.

   @param[in] par_wave The wavelet coefficients, a vector(logits*classes).
   @param[in,out] params The extra parameter to passes to the function,
   a pointer to a \em mple_wave_mc_struct. Its \em logits,
   \em deriv_logits and \em probs members are overwritten.
   @param[out] f The penalized log likelihood. Skipped if NULL.
   @param[out] df The gradient of the penalized log likelihood,
   a vector(logits*classes). Skipped if NULL.
   @param[out] df2 The Hessian of the penalized log likelihood,
   a matrix(logits*classes x logits*classes). Skipped if NULL.

   This function is not used directly by the root finding functions,
   but by others functions that comply with the gsl.

   \return GSL_SUCCESS for success, GSL_FAILURE as soon as a NaN shows
   up in one of the requested outputs (which are then only partially
   accumulated).
*/
int
mple_wave_mc_fdfdf2 (const gsl_vector * par_wave, void *params,
		     double * f, gsl_vector * df, gsl_matrix * df2)
{
  /* quad_weights is not read here: the weight is fixed to 1 below */
  gsl_vector *quad_sizes = ((mple_wave_mc_struct *) params)->quad_sizes;
  gsl_matrix *quad_freqs = ((mple_wave_mc_struct *) params)->quad_freqs;
  double smooth_factor = ((mple_wave_mc_struct *) params)->smooth_factor;
  gsl_wavelet *wave = ((mple_wave_mc_struct *) params)->wave;
  gsl_wavelet_workspace *work = ((mple_wave_mc_struct *) params)->work;
  gsl_matrix *wavelets = ((mple_wave_mc_struct *) params)->wavelets;
  gsl_matrix *deriv_wavelets = ((mple_wave_mc_struct *) params)->deriv_wavelets;
  gsl_matrix *logits = ((mple_wave_mc_struct *) params)->logits;
  gsl_matrix *deriv_logits =
    ((mple_wave_mc_struct *) params)->deriv_logits;
  gsl_matrix *probs = ((mple_wave_mc_struct *) params)->probs;
  double size, freq, deriv_logit, grad, hess, weight, prob;
  int j, i, k, nbr_logit = logits->size1, nbr_quad = quad_freqs->size2, s, t;

  /* reset the requested outputs to zero, they are accumulated below */
  if (f) *f = 0;
  if (df)
    gsl_vector_set_all (df, 0.0);
  if (df2)
    gsl_matrix_set_all (df2, 0.0);

  /* retrieve the logit and theirs second derivatives */
  for (i = 0; i < nbr_logit; i++)
    {
      /* copy the wavelet coefficients and do the inverse wavelet transform */
      for (k = 0; k < nbr_quad; k++)
	{
	  gsl_matrix_set (logits, i, k, gsl_vector_get (par_wave, i*nbr_quad+k));
	}
      /* row i is transformed in place through the raw data pointer, which
	 assumes the rows of logits are stored nbr_quad elements apart */
      gsl_wavelet_transform_inverse (wave, &(logits->data[i*nbr_quad]), 1, nbr_quad, work);

      for (k = 0; k < nbr_quad; k++)
	{
	  /* reset the logit if too big in absolute value */
	  if (gsl_matrix_get (logits, i, k) > VERY_BIG_LOGIT)
	    gsl_matrix_set (logits, i, k, VERY_BIG_LOGIT);
	  if (gsl_matrix_get (logits, i, k) < -VERY_BIG_LOGIT)
	    gsl_matrix_set (logits, i, k, -VERY_BIG_LOGIT);

	  /* compute the (2)th derivative of the logit as the combination
	     of the (2)th derivatives of the wavelets; the raw coefficients
	     are used, so the clamping above has no effect here */
	  deriv_logit = 0;
	  for (s = 0; s < nbr_quad; s++)
	    deriv_logit += gsl_vector_get (par_wave, s+i*nbr_quad)
	      * gsl_matrix_get (deriv_wavelets, s, k);
	  gsl_matrix_set (deriv_logits, i, k, deriv_logit);
	}
    }

  /* compute the option response functions */
  probs_from_logits(logits, probs);

  /* sum over the classes */
  for (k = 0; k < nbr_quad; k++)
    {
      /* for each logit (one less than the number of option) */
      for (i = 0; i < nbr_logit; i++)
	{
	  prob = gsl_matrix_get (probs, i, k);
	  size = gsl_vector_get (quad_sizes, k);
	  weight = 1; /* the quadrature weight of the class is not used */
	  freq = gsl_matrix_get (quad_freqs, i, k);
	  deriv_logit = gsl_matrix_get (deriv_logits, i, k);

	  /* update the llk: data term of option i minus the roughness
	     penalty of logit i */
	  if (f) {
	    *f += freq * log(prob)
	      - smooth_factor * weight * deriv_logit * deriv_logit;
	    if(gsl_isnan((*f))) return GSL_FAILURE;
	  }

	  /* update the gradient; with logit = log(P_last/P_i) the
	     derivative of the data term with respect to the logit i is
	     -(freq - size * prob) */
	  if (df)
	    /* for each wavelets coefficients of the logit i */
	    for (s = 0; s < nbr_quad; s++)
	      {
		grad = gsl_vector_get (df, s+i*nbr_quad);
		grad -= (freq - size * prob) * gsl_matrix_get (wavelets, s, k);
		grad -= smooth_factor * weight * 2 * deriv_logit
		  * gsl_matrix_get (deriv_wavelets, s, k);
		gsl_vector_set (df, s+i*nbr_quad, grad);
		if(gsl_isnan(grad)) return GSL_FAILURE;
	      }

	  /* update the Hessian; only the lower triangle is accumulated
	     (j <= i and, inside the diagonal block, t <= s) */
	  if (df2)
	    {
	    /* for each wavelets coefficients of the logit i */
	      for (s = 0; s < nbr_quad; s++)
		/* for each logit j */
		for (j = 0; j <= i; j++)
		  /* for each wavelets coefficients of the logit j */
		  for (t = 0; t < ((j==i)?(s+1):(nbr_quad)); t++)
		    {
		      hess = gsl_matrix_get (df2, s+i*nbr_quad, t+j*nbr_quad);
		      /* data term: -size * P_i * ((i==j) - P_j) */
		      hess -= prob * ((i==j) - gsl_matrix_get (probs, j, k)) * size 
			* gsl_matrix_get (wavelets, s, k)
			* gsl_matrix_get (wavelets, t, k);
		      /* penalty term, only inside the block of a same logit */
		      hess -= smooth_factor * weight * 2 * (i==j)
			* gsl_matrix_get (deriv_wavelets, s, k)
			* gsl_matrix_get (deriv_wavelets, t, k);
		      gsl_matrix_set (df2, s+i*nbr_quad, t+j*nbr_quad, hess);
		      if(gsl_isnan(hess)) return GSL_FAILURE;
		    }
	    }
	}

      /* last option: after the loop i == nbr_logit, which is the row
	 following the logit rows in probs and quad_freqs */
      prob = gsl_matrix_get (probs, i, k);
      freq = gsl_matrix_get (quad_freqs, i, k);
      /* update the llk (no penalty, the last option has no logit) */
      if (f) {
	*f += freq * log(prob);
	if(gsl_isnan((*f))) return GSL_FAILURE;
      }

    }
  
  /* copy the lower half of the Hessian to the upper half */
  if (df2)
    for (i = 0; i < nbr_logit*nbr_quad; i++)
      for (j = 0; j < i; j++)
	gsl_matrix_set (df2, j, i, gsl_matrix_get (df2, i, j));

  return GSL_SUCCESS;
}

/**
   \brief Evaluate together the gradient and the Hessian of the
   penalized log likelihood at the given wavelet coefficients.

   @param[in] par_wave The wavelet coefficients.
   @param[in] params The extra parameter to passes to the function.
   @param[out] df The gradient of the penalized log likelihood.
   @param[out] df2 The Hessian of the penalized log likelihood.

   Thin adapter around \em mple_wave_mc_fdfdf2 (the log likelihood
   itself is not requested) with the signature expected for the
   \em fdf member of a gsl_multiroot_function_fdf.

   \return The status of \em mple_wave_mc_fdfdf2, GSL_SUCCESS for success.
*/
int
mple_wave_mc_dfdf2 (const gsl_vector * par_wave, void *params,
		 gsl_vector * df, gsl_matrix * df2)
{
  int status;

  status = mple_wave_mc_fdfdf2 (par_wave, params, NULL, df, df2);

  return status;
}

/**
   \brief Evaluate the gradient of the penalized log likelihood at the
   given wavelet coefficients.

   @param[in] par_wave The wavelet coefficients.
   @param[in] params The extra parameter to passes to the function.
   @param[out] df The gradient of the penalized log likelihood.

   Thin adapter around \em mple_wave_mc_fdfdf2 (only the gradient is
   requested) with the signature expected for the \em f member of a
   gsl_multiroot_function_fdf: the root searched is a zero of the gradient.

   \return The status of \em mple_wave_mc_fdfdf2, GSL_SUCCESS for success.
*/
int
mple_wave_mc_df (const gsl_vector * par_wave, void *params, gsl_vector * df)
{
  int status;

  status = mple_wave_mc_fdfdf2 (par_wave, params, NULL, df, NULL);

  return status;
}

/**
   \brief Evaluate the Hessian of the penalized log likelihood at the
   given wavelet coefficients.

   @param[in] par_wave The wavelet coefficients.
   @param[in] params The extra parameter to passes to the function.
   @param[out] df2 The Hessian of the penalized log likelihood.

   Thin adapter around \em mple_wave_mc_fdfdf2 (only the Hessian is
   requested) with the signature expected for the \em df member of a
   gsl_multiroot_function_fdf.

   \return The status of \em mple_wave_mc_fdfdf2, GSL_SUCCESS for success.
*/
int
mple_wave_mc_df2 (const gsl_vector * par_wave, void *params, gsl_matrix * df2)
{
  int status;

  status = mple_wave_mc_fdfdf2 (par_wave, params, NULL, NULL, df2);

  return status;
}

/**
   \brief Does the maximization step of the EM algorithm to
   estimate the response functions by PMMLE (Penalized Maximum Marginal Likelihood)
   of one multiple choice item.

   The wavelet coefficients of the logits are started from the current
   \em probs, then the gsl "gnewton" root finder is run on the gradient
   of the penalized log likelihood.

   @param[in] max_iter The maximum number of Newton iterations performed for the item.
   @param[in] prec The desired precision of each wavelet parameter estimate
   (absolute tolerance on the last Newton step).
   @param[in,out] params The extra parameter to passes to the function.
   The members \em quad_sizes, \em quad_freqs, \em smooth_factor, \em wave,
   \em work, \em wavelets and \em deriv_wavelets must be set by the caller.
   The members \em logits and \em deriv_logits are allocated here and are
   NULL on return; \em probs is set to the \em probs argument.
   @param[in,out] probs A matrix(options x classes) with the estimated response
   functions of the item. They should be initialize first. On return
   every value lies in [VERY_SMALL_PROB, 1 - VERY_SMALL_PROB].
   @param[out] probs_stddev A matrix(options x classes) with the standard errors
   of the response functions, obtained from the inverse Hessian by a
   first order Taylor approximation. Not computed if NULL.
   @param[out] mllk The penalized log likelihood at the solution found.

   \return 0 if the item converged, 1 otherwise.
   
   \warning The memory for the response functions should be allocated before.
*/
int
mple_wave_mc (int max_iter, double prec,
	      mple_wave_mc_struct * params,
	      gsl_matrix * probs, gsl_matrix * probs_stddev,
	      double * mllk)
{
  const gsl_multiroot_fdfsolver_type *algo;
  gsl_multiroot_fdfsolver *solver;
  int status, iter, nbr_option = probs->size1, nbr_logit = nbr_option-1,
    i, j, k, s, t, ret_val, nbr_quad = probs->size2,
    nbr_par = nbr_logit * nbr_quad;
  gsl_multiroot_function_fdf FDF;
  double sum_grad, prob;

  /* the vector with all the wavelets parameters */
  gsl_vector * par_wave = gsl_vector_alloc(nbr_par);

  /* allocate the memory to compute the logits and its derivatives */
  params->logits = gsl_matrix_alloc (nbr_logit, nbr_quad);
  params->deriv_logits = gsl_matrix_alloc (nbr_logit, nbr_quad);
  params->probs = probs;

  /* initalize the function to solve: the "function" is the gradient
     and its "jacobian" is the Hessian */
  FDF.f = &mple_wave_mc_df;
  FDF.df = &mple_wave_mc_df2;
  FDF.fdf = &mple_wave_mc_dfdf2;
  FDF.n = nbr_par;
  FDF.params = params;

  /* select the algorithm to used */
  algo = gsl_multiroot_fdfsolver_gnewton;
  /* allocate the solver */
  solver = gsl_multiroot_fdfsolver_alloc (algo, nbr_par);

  /* 0 means converged, it is incremented if the solver fails */
  ret_val = 0;

  /* set the starting values ... */
  /* ... use the "logit" transform */
  logits_from_probs (probs, params->logits);

  for (i = 0; i < nbr_logit; i++)
    {
      /* ... copy the result */
      for (k = 0; k < nbr_quad; k++)
	{
	  gsl_vector_set (par_wave, i*nbr_quad+k, gsl_matrix_get (params->logits, i, k));
	}
      /* ... and use a wavelet transform */
      gsl_wavelet_transform_forward (params->wave, &(par_wave->data[i*nbr_quad]), 1,
				     nbr_quad, params->work);
    }

  /* set the solver */
  gsl_multiroot_fdfsolver_set (solver, &FDF, par_wave);

  /* iterate the solver */
  iter = 0;
  do
    {
      iter++;
      
      /* check if the hessian is singular (a zero on its diagonal) */
      status = 0;
      for (k = 0; k < nbr_par; k++)
	{
	  if(0 == gsl_matrix_get (solver->J, k, k))
	    {
	      status = GSL_EBADFUNC;
	      break;
	    }
	}
      if (status) break;

      status = gsl_multiroot_fdfsolver_iterate (solver);
      
      if (libirt_verbose > 7)
	{
	  sum_grad = 0;
	  for (k = 0; k < nbr_par; k++) sum_grad += fabs(gsl_vector_get(solver->f, k));
	  printf ("\n At N-R iteration %d sum|grad(PML)| is %8.2e.\n",
		  iter, sum_grad);
	}
      
      if (status)
	break;
      
      /* test for convergence */
      status = gsl_multiroot_test_delta (solver->dx, solver->x, prec, 0);
      /* status = gsl_multiroot_test_residual (solver->f, prec); */
      
    }
  while (status == GSL_CONTINUE && iter < max_iter);

  /* compute the maximum log likelihood to return */
  mple_wave_mc_fdfdf2 (solver->x, params, mllk, NULL, NULL);

  /* check if this item converged */
  if (status != GSL_SUCCESS)
    {
      ret_val++;
    }
  
  if (libirt_verbose > 3)
    {
      if (status == GSL_CONTINUE)
	printf (" did not converged (max iter)");
      else if (status == GSL_EBADFUNC)
	printf (" did not converged (singular point)");
      else if (status == GSL_ENOPROG)
	printf (" did not converged (no progress)");
      else if (status == GSL_ENOPROGJ)
	printf (" did not converged (jacobian no progress)");
      else if (status == GSL_SUCCESS)
	printf (" converged (success)");
      else
	printf (" unknow status (%d)", status);
      printf (" after %d iterations.\n", iter);
      fflush (stdout);
    }
  
  /* copy the solution found */
  for (i = 0; i < nbr_logit; i++)
    {
      /* copy the wavelet coefficients and do the inverse wavelet transform */
      for (k = 0; k < nbr_quad; k++)
	{
	  gsl_matrix_set (params->logits, i, k, gsl_vector_get (solver->x, i*nbr_quad+k));
	}
      gsl_wavelet_transform_inverse (params->wave, &(params->logits->data[i*nbr_quad]), 
				     1, nbr_quad, params->work);
    }

  /* compute the option response functions */
  probs_from_logits(params->logits, probs);

  /* compute the standard errors */
  if (probs_stddev)
    {
      /* work space only needed here, so it is allocated and freed locally */
      gsl_matrix *df2 = gsl_matrix_alloc (nbr_par, nbr_par);
      gsl_matrix *inv_df2 = gsl_matrix_alloc (nbr_par, nbr_par);
      gsl_permutation *lu_perm = gsl_permutation_alloc (nbr_par);
      gsl_matrix *G = gsl_matrix_alloc (nbr_logit, nbr_option);
      gsl_matrix *COV = gsl_matrix_alloc (nbr_logit, nbr_logit);
      int lu_sign;
      double var;

      /* get the Hessian; this evaluation also rewrites params->logits
	 and probs (through params->probs), which are the values used
	 in G below */
      mple_wave_mc_df2 (solver->x, params, df2);

      /* inverse it */
      gsl_linalg_LU_decomp (df2, lu_perm, &lu_sign);
      gsl_linalg_LU_invert (df2, lu_perm, inv_df2);

      /* for each classes */
      for (k = 0; k < nbr_quad; k++)
	{
	  /* compute the covariance between the logit 
	     and the derivative of the transformation from the logit to the probs 
	     each transformation (option) is in a column of G */
	  for (i = 0; i < nbr_logit; i++)
	    {
	      for (j = 0; j < nbr_logit; j++)
		{
		  /* covariance between the logit: minus the inverse
		     Hessian mapped back from the wavelet coefficients */
		  var = 0;
		  for (s = 0; s < nbr_quad; s++)
		    for (t = 0; t < nbr_quad; t++)
		      var -= gsl_matrix_get (params->wavelets, s, k)
			* gsl_matrix_get (params->wavelets, t, k)
			* gsl_matrix_get (inv_df2, i*nbr_quad+s, j*nbr_quad+t);
		  gsl_matrix_set(COV, i, j, var);
		  
		  /* derivative of the transformation */
		  if (i == j)
		    gsl_matrix_set(G, j, j, -gsl_matrix_get(probs, j, k)
				   * (1 - gsl_matrix_get(probs, j, k)));
		  else
		    gsl_matrix_set(G, i, j, gsl_matrix_get(probs, j, k)
				   * gsl_matrix_get(probs, j, k)
				   * exp(gsl_matrix_get(params->logits, j, k)
					 - gsl_matrix_get(params->logits, i, k))); 
		}
	      /* the derivatives for the transformation of the last option */
	      j = nbr_option-1;
	      gsl_matrix_set(G, i, j, gsl_matrix_get(probs, j, k)
			     * gsl_matrix_get(probs, j, k)
			     * exp(-gsl_matrix_get(params->logits, i, k)));
	    }

	  /* compute the variance of each option curves by Taylor approximation */
	  for (i = 0; i < nbr_option; i++)
	    {
	      var = 0;
	      for(s = 0; s < nbr_logit; s++)
		for(t = 0; t < nbr_logit; t++)
		  var += gsl_matrix_get(G, s, i) * gsl_matrix_get(G, t, i)
		    * gsl_matrix_get(COV, s, t);
	      gsl_matrix_set(probs_stddev, i, k, sqrt(var));
	    }
	}

      /* free the work space (inv_df2 included, it used to leak) */
      gsl_matrix_free (df2);
      gsl_matrix_free (inv_df2);
      gsl_permutation_free (lu_perm);
      gsl_matrix_free (G);
      gsl_matrix_free (COV);
    }

  /* reset the probabilities inside the open interval (0,1); this is
     done after the standard errors because the Hessian evaluation
     above rewrites probs and would undo it */
  for (i = 0; i < nbr_option; i++)
    for (k = 0; k < nbr_quad; k++)
      {
	prob = gsl_matrix_get(probs, i, k);
	if (prob < VERY_SMALL_PROB) prob = VERY_SMALL_PROB;
	if (prob > 1 - VERY_SMALL_PROB) prob = 1 - VERY_SMALL_PROB;
        gsl_matrix_set(probs, i, k, prob);
      }

  /* free the memory */
  gsl_multiroot_fdfsolver_free (solver);
  gsl_vector_free (par_wave);
  gsl_matrix_free (params->logits);
  gsl_matrix_free (params->deriv_logits);
  /* do not leave dangling pointers in the caller's structure */
  params->logits = NULL;
  params->deriv_logits = NULL;

  return ret_val;
}

/**
   \brief Estimate the options response functions by PMMLE (Penalized Maximum Marginal Likelihood).

   Each EM iteration computes the posteriors and the expected
   frequencies (E step), then calls \em mple_wave_mc on every item that
   is not ignored (M step).

   @param[in] max_em_iter The maximum number of EM iterations. The stopping
   test on the likelihood is only applied from the 20th iteration on.
   @param[in] max_nr_iter The maximum number of Newton iterations performed
   for each item at each EM iteration.
   @param[in] prec The relative change in the likelihood to stop the EM algorithm.
   This value divided by 10 is also the desired precision of each parameter estimate.
   @param[in] smooth_factor The factor to the penality term. It is
   multiplied here by the mean distance between two quadrature points.
   @param[in] patterns A matrix(patterns x options) of binary responses.
   @param[in] counts A vector(patterns) with the count of each pattern.
   If NULL the counts are assumed to be all 1.
   @param[in] quad_points A vector(classes) with the middle points of each quadrature class.
   @param[in,out] quad_weights A vector(classes) with the prior weights of each quadrature class.
   It is passed to \em adjust_quad_weights after each iteration when
   \em adjust_weights is set.
   @param[in] items_pos A vector(items) with the position of the first option of each item
   in patterns (and probs).
   @param[in] nbr_options A vector(items) with the number of option of each item
   in patterns (and probs).
   @param[in,out] probs A matrix(options x classes) with the estimated response functions.
   They should be initialize first.
   @param[out] probs_stddev A matrix(options x classes) with the standard errors
   of the response functions. If NULL the standard errors are not computed.
   @param[in] ignore A vector(items) of ignore flag. If NULL no item is ignored.
   @param[out] nbr_notconverge The number of items that didn't converged
   at the last EM iteration performed.
   @param[out] notconverge A vector(items) of flag set for the items that didn't converged.
   The entries of the ignored items are left untouched.
   @param[in] adjust_weights Controls whether adjust the quadrature weights after each iteration.

   \return 1 if the relative change in the maximum log likelihood was less than prec
   else 0.
   
   \warning The memory for the outputs should be allocated before.
*/
int
em_mple_wave_mc (int max_em_iter, int max_nr_iter, double prec, double smooth_factor, 
		 gsl_matrix_int * patterns, gsl_vector * counts,
		 gsl_vector * quad_points, gsl_vector * quad_weights, 
		 gsl_vector_int * items_pos, gsl_vector_int * nbr_options,
		 gsl_matrix * probs, gsl_matrix * probs_stddev,
		 gsl_vector_int * ignore,
		 int * nbr_notconverge, gsl_vector_int * notconverge,
		 int adjust_weights)
{
  int em_iter, nbr_quad, nbr_pattern, nbr_item, nbr_option_tot, nbr_option,
    ret_val, s, k, j, i, pos;
  double deriv, step, prob, nbr_subject, mllk=0, mllk_old=0, mllk_i;
  gsl_matrix *quad_freqs, *post;
  gsl_vector *quad_sizes;
  gsl_matrix_view quad_freqs_i, probs_i, probs_stddev_i;
  gsl_matrix *probs_stddev_item;
  mple_wave_mc_struct params;

  nbr_quad = quad_points->size;
  nbr_pattern = patterns->size1;
  nbr_option_tot = patterns->size2;
  nbr_item = items_pos->size;

  params.quad_weights = quad_weights;

  nbr_subject = 0;
  /* count the number of subject */
  for(j = 0; j < nbr_pattern; j++)
    nbr_subject += counts ? gsl_vector_get(counts, j) : 1;

  /* adjust the smoothing factor for the quadratures widths factor */
  smooth_factor *= (gsl_vector_get(quad_points, nbr_quad - 1) 
                    - gsl_vector_get(quad_points, 0)) / (nbr_quad - 1);
  params.smooth_factor = smooth_factor;

  /* allocate the memory */
  quad_freqs = gsl_matrix_alloc (nbr_option_tot, nbr_quad);
  quad_sizes = gsl_vector_alloc (nbr_quad);
  post = gsl_matrix_alloc (nbr_pattern, nbr_quad);
  params.wavelets = gsl_matrix_alloc (nbr_quad, nbr_quad);
  params.deriv_wavelets = gsl_matrix_alloc (nbr_quad, nbr_quad);

  /* initialize the wavelets transform */
  params.wave = gsl_wavelet_alloc (gsl_wavelet_daubechies, 10);
  params.work = gsl_wavelet_workspace_alloc (nbr_quad);

  /* compute the wavelets and theirs (2)th derivatives */
  /* set to zero */
  gsl_matrix_set_all (params.wavelets, 0);
  for (s = 0; s < nbr_quad; s++)
    {
      /* select the wavelet */
      gsl_matrix_set (params.wavelets, s, s, 1);
      /* compute the wavelet: row s becomes the inverse transform of
	 the s-th unit vector */
      gsl_wavelet_transform_inverse (params.wave,
				     params.wavelets->data + s * nbr_quad,
				     1, nbr_quad, params.work);
      /* compute its (2)th derivatives by central finite differences */
      /** \todo Compute more accurates wavelets derivatives */
      /* start from zero so that the value copied to the borders is
	 defined even when there is no interior point */
      deriv = 0;
      for (k = 1; k < nbr_quad - 1; k++)
	{
	  deriv = gsl_matrix_get (params.wavelets, s, k + 1)
	    - 2 * gsl_matrix_get (params.wavelets, s, k)
	    + gsl_matrix_get (params.wavelets, s, k - 1);
	  step = (gsl_vector_get (quad_points, k + 1) -
		  gsl_vector_get (quad_points, k - 1)) / 2;
	  step *= step;
	  deriv /= step;
	  gsl_matrix_set (params.deriv_wavelets, s, k, deriv);
	}
      /* the two borders reuse the value of their neighbour
	 (here k == nbr_quad - 1) */
      gsl_matrix_set (params.deriv_wavelets, s, k, deriv);
      gsl_matrix_set (params.deriv_wavelets, s, 0,
		      gsl_matrix_get (params.deriv_wavelets, s, 1));
    }

  /* reset the probabilities inside the open interval (0,1) */
  for (i = 0; i < nbr_option_tot; i++)
    for (k = 0; k < nbr_quad; k++)
      {
	prob = gsl_matrix_get(probs, i, k);
	if (prob < VERY_SMALL_PROB) prob = VERY_SMALL_PROB;
	if (prob > 1 - VERY_SMALL_PROB) prob = 1 - VERY_SMALL_PROB;
        gsl_matrix_set(probs, i, k, prob);
      }

  /* EM iterations */

  for (em_iter = 1; em_iter <= max_em_iter; em_iter++)
    {
      /* E (estimation) step */

      if (libirt_verbose > 2)
	printf ("\nEM iteration %d\n", em_iter);

      /* compute the posterior prob */
      posteriors_mc (patterns, probs, nbr_options, items_pos, quad_weights, post);

      /* compute the expected sizes and frequencies */
      frequencies (patterns, counts, post, probs, quad_sizes, quad_freqs);

      /* print debugging information */
      if (libirt_verbose > 5)
	{
	  for (i = 0; i < nbr_item; i++)
	    for (j = 0; j < gsl_vector_int_get(nbr_options, i); j++)
	      {
		pos = gsl_vector_int_get(items_pos,i);
		printf("Probabilities for option %d of item %d :\n", j+1, i+1);
		for (k = 0; k < nbr_quad; k++)
		  printf(" %8.2e", gsl_matrix_get(probs, pos+j, k));
		printf("\n");
	      }
	  for (j = 0; j < nbr_pattern; j++)
	    {
	      printf("Posterior for pattern %d :\n", j+1);
	      for (k = 0; k < nbr_quad; k++)
		printf(" %8.2e", gsl_matrix_get(post,j,k));
	      printf("\n");
	    }
	  printf("Sizes :\n");
	  for (k = 0; k < nbr_quad; k++)
	    printf(" %8.2e", gsl_vector_get(quad_sizes,k));
	  printf("\n");
	  for (i = 0; i < nbr_item; i++)
	    for (j = 0; j < gsl_vector_int_get(nbr_options, i); j++)
	      {
		pos = gsl_vector_int_get(items_pos,i);
		printf("Frequencies for option %d of item %d :\n", j+1, i+1);
		for (k = 0; k < nbr_quad; k++)
		  printf(" %8.2e", gsl_matrix_get(quad_freqs,pos+j,k));
		printf("\n");
	      }
	}

      /* the number of item that do not converge */
      *nbr_notconverge = 0;
      
      mllk = 0;

      /* M (maximisation) step */
      for (i = 0; i < nbr_item; i++)
	{
	  /* ignore the degenerate items */
	  if (ignore && gsl_vector_int_get(ignore, i)) continue;

	  /* get the corresponding rows of freqs, sizes, probs and probs_stddev */
	  nbr_option = gsl_vector_int_get(nbr_options, i);
	  pos = gsl_vector_int_get(items_pos, i);
	  quad_freqs_i = gsl_matrix_submatrix(quad_freqs, pos, 0, nbr_option, nbr_quad);
	  probs_i = gsl_matrix_submatrix(probs, pos, 0, nbr_option, nbr_quad);
	  /* the standard errors are optional: a view can only be taken
	     on an existing matrix, otherwise NULL is forwarded */
	  if (probs_stddev)
	    {
	      probs_stddev_i = gsl_matrix_submatrix(probs_stddev, pos, 0, nbr_option, nbr_quad);
	      probs_stddev_item = &probs_stddev_i.matrix;
	    }
	  else
	    probs_stddev_item = NULL;
	  params.quad_freqs = &quad_freqs_i.matrix;
	  params.quad_sizes = quad_sizes;

	  /* use a root finding algorithm */
	  if (libirt_verbose > 3)
	    printf ("item %d", i + 1);

	  /* returns 0 when the item converged, 1 otherwise */
	  ret_val = mple_wave_mc (max_nr_iter, prec/10, &params, &probs_i.matrix, 
				  probs_stddev_item, &mllk_i);

	  *nbr_notconverge += ret_val;

	  mllk += mllk_i;
	  
	  gsl_vector_int_set(notconverge, i, ret_val);
	}      

      if(gsl_isnan(mllk)) {
	if (libirt_verbose > 1) printf("NAN error ! Stopping.\n");
	break;
      }

      if(adjust_weights)
	adjust_quad_weights (nbr_subject, quad_sizes, quad_points, quad_weights);

      if (libirt_verbose > 2)
	printf("MLLK = %10.3e %%CHANGE = %9.3e\n", mllk, fabs((mllk-mllk_old)/mllk));

      /* if the change in the maximum log likelihood is small then exit */
      if (fabs((mllk-mllk_old)/mllk) < prec && em_iter >= 20) break;

      mllk_old = mllk;
    }

  /* check if the EM algo converged: the loop was left by a break
     that was not caused by a NaN */
  if (em_iter <= max_em_iter && !gsl_isnan(mllk)) ret_val = 1;
  else ret_val = 0;

  if (libirt_verbose > 0 && ret_val == 0)
    printf("The EM algorithm didn't converged after %d iterations.\n", em_iter-1);

  if (libirt_verbose > 0 && ret_val == 1)
    printf("The EM algorithm converged after %d iterations.\n", em_iter);

  /* free the memory */
  gsl_matrix_free (quad_freqs);
  gsl_vector_free (quad_sizes);
  gsl_matrix_free (post);
  gsl_wavelet_free (params.wave);
  gsl_wavelet_workspace_free (params.work);
  gsl_matrix_free (params.wavelets);
  gsl_matrix_free (params.deriv_wavelets);

  return ret_val;
}
