/*

  mlmc_test(mlmc_l, N,L, N0,Eps,Lmin,Lmax, fp)

  multilevel Monte Carlo test routine

   mlmc_l(Nl,sums)     low-level routine
   inputs: Nl = minimum required paths on each level (terminated by 0)
   output: sums[8*l+0] = sum(1)
           sums[8*l+1] = sum(cost)
           sums[8*l+2] = sum(Pf-Pc)
           sums[8*l+3] = sum((Pf-Pc).^2)
           sums[8*l+4] = sum((Pf-Pc).^3)
           sums[8*l+5] = sum((Pf-Pc).^4)
           sums[8*l+6] = sum(Pf)
           sums[8*l+7] = sum(Pf.^2)

   N      = number of samples for convergence tests
   L      = number of levels for convergence tests

   N0     = initial number of samples
   Eps    = desired accuracy array (terminated by 0)
   Lmin   = minimum level of refinement
   Lmax   = maximum level of refinement

   fp     = file handle for output
*/

#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <time.h>
#include <string.h>

#include <mlmc.cpp>

// https://gcc.gnu.org/onlinedocs/cpp/Variadic-Macros.html
// Variadic macro that prints the same formatted message to both stdout
// and the file fp.
//
// The expansion is wrapped in do { } while (0) so that "PRINTF2(...);"
// is exactly one statement and can be used as an un-braced if/else branch.
// Note: the format arguments are expanded (and hence evaluated) twice,
// so they must not have side effects.
#define PRINTF2(fp, ...) do {printf(__VA_ARGS__);fprintf(fp,__VA_ARGS__);} while (0)


/*
  mlmc_test: run the MLMC convergence tests on levels 0..L with N samples
  per level, estimate alpha/beta/gamma by linear regression, then run the
  MLMC complexity test for every accuracy in Eps.  All report lines are
  written to both stdout and fp.

   mlmc_l  low-level routine; called as mlmc_l(Nl,sums)
   N       number of samples per level for the convergence tests
   L       number of levels for the convergence tests (must be >= 0)
   N0      initial number of samples, passed through to mlmc()
   Eps     array of desired accuracies, terminated by a value <= 0
   Lmin    minimum level of refinement, passed through to mlmc()
   Lmax    maximum level of refinement, passed through to mlmc()
   fp      open file handle that receives a copy of the output

  All working storage is allocated and released inside this routine.
*/
void mlmc_test(void (*mlmc_l)(long long *, double *), int N,int L,
               int N0, float *Eps, int Lmin, int Lmax, FILE *fp) {

  if (L < 0) {
    fprintf(stderr,"mlmc_test: number of levels L must be non-negative\n");
    return;
  }

//
// working storage
//

  // The level-indexed buffers shared with mlmc_l / mlmc keep their
  // historical minimum of 21 levels, but grow if L or Lmax needs more.
  int maxlev = 20;
  if (L    > maxlev) maxlev = L;
  if (Lmax > maxlev) maxlev = Lmax;

  size_t nlev = (size_t) L + 1;               // levels 0..L
  size_t nreg = (L > 0) ? (size_t) L : 1;     // regression points, levels 1..L

  double    *sums = (double *)    malloc((size_t)(maxlev+1)*8*sizeof(double));
  long long *Nl   = (long long *) malloc((size_t)(maxlev+2)*sizeof(long long));
  float     *Cl   = (float *)     malloc((size_t)(maxlev+1)*sizeof(float));

  // one block holds the seven per-level statistics plus regression x,y
  float     *work = (float *)     malloc((7*nlev + 2*nreg)*sizeof(float));

  if (sums==NULL || Nl==NULL || Cl==NULL || work==NULL) {
    fprintf(stderr,"mlmc_test: out of memory\n");
    free(sums);
    free(Nl);
    free(Cl);
    free(work);
    return;
  }

  float *cost = work;
  float *del1 = cost + nlev;
  float *del2 = del1 + nlev;
  float *var1 = del2 + nlev;
  float *var2 = var1 + nlev;
  float *chk1 = var2 + nlev;
  float *kur1 = chk1 + nlev;
  float *x    = kur1 + nlev;
  float *y    = x    + nreg;

//
// first, convergence tests
//

  // current date/time based on current system; work on a private copy
  // and replace ctime's trailing newline by a space
  char date[64] = "unknown date ";
  time_t now = time(NULL);
  const char *stamp = ctime(&now);
  if (stamp != NULL) {
    snprintf(date, sizeof(date), "%s", stamp);
    size_t len = strlen(date);
    if (len > 0) date[len-1] = ' ';
  }

  PRINTF2(fp,"\n");
  PRINTF2(fp,"**********************************************************\n");
  PRINTF2(fp,"*** MLMC file version 1.0     produced by              ***\n");
  PRINTF2(fp,"*** C++/CUDA mlmc_test on %s    ***\n",date);
  PRINTF2(fp,"**********************************************************\n");
  PRINTF2(fp,"\n");
  PRINTF2(fp,"**********************************************************\n");
  PRINTF2(fp,"*** Convergence tests, kurtosis, telescoping sum check ***\n");
  PRINTF2(fp,"*** using N =%7d samples                           ***\n",N);
  PRINTF2(fp,"**********************************************************\n");
  PRINTF2(fp,"\n l   ave(Pf-Pc)    ave(Pf)   var(Pf-Pc)    var(Pf)");
  PRINTF2(fp,"    kurtosis     check        cost \n--------------------------");
  PRINTF2(fp,"-------------------------------------------------------------\n");

  // request N paths on every level 0..L; the list is terminated by 0
  for (int l=0; l<=L; l++) {
    Nl[l] = N;
    for (int m=0; m<8; m++)
      sums[8*l+m] = 0.0;
  }
  Nl[L+1] = 0;

  // call mlmc_l, with second call to terminate CUDA kernels

  mlmc_l(Nl,sums);
  Nl[0]=0;
  mlmc_l(Nl,sums);

  // print out various bits of information

  for (int l=0; l<=L; l++) {
    // averages of the seven accumulated quantities sums[8*l+1 .. 8*l+7];
    // this is a fixed-size scratch array, independent of L
    float sum[7];
    for (int m=0; m<7; m++) sum[m] = sums[8*l+m+1]/sums[8*l];

    cost[l] = sum[0];
    del1[l] = sum[1];
    del2[l] = sum[5];
    // variances are floored at 1e-10 so the later divisions and
    // logarithms stay finite
    var1[l] = fmax(sum[2]-sum[1]*sum[1], 1e-10);
    var2[l] = fmax(sum[6]-sum[5]*sum[5], 1e-10);

    // fourth central moment of (Pf-Pc) divided by its variance squared
    kur1[l]  = (      sum[4]
                - 4.0*sum[3]*sum[1]
                + 6.0*sum[2]*sum[1]*sum[1]
                - 3.0*sum[1]*sum[1]*sum[1]*sum[1] )
             / (var1[l]*var1[l]);

    // telescoping sum check: ave(Pf-Pc)[l] + ave(Pf)[l-1] - ave(Pf)[l]
    // scaled by three times the sum of the standard deviations
    if (l==0)
      chk1[l] = 0.0f;
    else
      chk1[l] = sqrtf((float) N) *
                fabsf(  del1[l]  +       del2[l-1]  -       del2[l] )
         / (3.0f*(sqrtf(var1[l]) + sqrtf(var2[l-1]) + sqrtf(var2[l])));

    PRINTF2(fp,"%2d  %11.4e %11.4e %11.4e %11.4e %11.4e %11.4e %11.4e \n",
    l,del1[l],del2[l],var1[l],var2[l],kur1[l],chk1[l],cost[l]);
  }

//
// print out a warning if kurtosis or consistency check looks bad
//

  if (kur1[L] > 100.0f) {
    PRINTF2(fp,"\n WARNING: kurtosis on finest level = %f \n",kur1[L]);
    PRINTF2(fp," MLMC correction dominated by a few rare paths; \n");
  }

  float max_chk = 0.0f;
  for (int l=0; l<=L; l++) max_chk = fmaxf(max_chk,chk1[l]);
  if (max_chk > 1.0f) {
    PRINTF2(fp,"\n WARNING: maximum consistency error = %f \n",max_chk);
    PRINTF2(fp," identity E[Pf-Pc] = E[Pf] - E[Pc] not satisfied ? \n");
  }

//
// use linear regression to estimate alpha, beta, gamma
//

  float alpha = 0.0f, beta = 0.0f, gamma = 0.0f, foo = 0.0f;

  for (int l=1; l<=L; l++) {
    x[l-1] = l;
    y[l-1] = - log2f(fabsf(del1[l]));
  }
  regression(L,x,y,alpha,foo);

  for (int l=1; l<=L; l++) {
    x[l-1] = l;
    y[l-1] = - log2f(var1[l]);
  }
  regression(L,x,y,beta,foo);

  for (int l=1; l<=L; l++) {
    x[l-1] = l;
    y[l-1] = log2f(cost[l]);
  }
  regression(L,x,y,gamma,foo);

  PRINTF2(fp,"\n******************************************************\n");
  PRINTF2(fp,"*** Linear regression estimates of MLMC parameters ***\n");
  PRINTF2(fp,"******************************************************\n");
  PRINTF2(fp,"\n alpha = %f  (exponent for MLMC weak convergence)\n",alpha);
  PRINTF2(fp," beta  = %f  (exponent for MLMC variance) \n",beta);
  PRINTF2(fp," gamma = %f  (exponent for MLMC cost) \n",gamma);

//
// second, mlmc complexity tests
//

  PRINTF2(fp,"\n");
  PRINTF2(fp,"***************************** \n");
  PRINTF2(fp,"*** MLMC complexity tests *** \n");
  PRINTF2(fp,"***************************** \n\n");
  PRINTF2(fp,"  eps       value   mlmc_cost   std_cost  savings     N_l \n");
  PRINTF2(fp,"--------------------------------------------------------- \n");

  int i=0;

  while (Eps[i]>0) {
    float eps = Eps[i++];
    float P = mlmc(Lmin,Lmax,N0,eps,mlmc_l, alpha,beta,gamma, Nl,Cl);
    float std_cost = 0.0f, mlmc_cost = 0.0f, theta=0.25f;

    // total MLMC cost, and the finest level actually used by mlmc();
    // the scan never leaves the allocated buffers
    int finest = 0;
    for (int l=0; l<=maxlev && Nl[l]>0; l++) {
      mlmc_cost += Nl[l]*Cl[l];
      finest = l;
    }

    // var2 is only known on the convergence-test levels 0..L, so when
    // mlmc() refined beyond L use the variance from level L instead of
    // reading past the end of var2
    int lvar = (finest < L) ? finest : L;
    std_cost = var2[lvar]*Cl[finest] / ((1.0f-theta)*eps*eps);

    PRINTF2(fp,"%.4f  %.4e  %.3e  %.3e  %7.2f ",
	    eps, P, mlmc_cost, std_cost, std_cost/mlmc_cost);
    for (int l=0; l<=Lmax && Nl[l]>0; l++) PRINTF2(fp,"%9lld",Nl[l]);
    PRINTF2(fp,"\n");
  }
  PRINTF2(fp,"\n");

  free(work);
  free(Cl);
  free(Nl);
  free(sums);
}
