/*
   Multilevel Monte Carlo (MLMC) tests with the stochastic heat equation,
   demonstrating the benefits of Richardson extrapolation
*/

//#include "mlmc_test.cpp"   // master MLMC file
#include "mlmc_test_100.cpp" // master file for 100 tests

#include "mlmc_rng.cpp"      // file with RNG functions
#include "mlmc_fftw.cpp"      // file with FFTW functions

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

//
// some declarations
//

// Selects how the noise on the fine grid is transferred to the coarser grids
// inside heat_l:  0 = finite element (FE),  1 = finite volume (FV),
//                 2 = spectral.
int coupling;

// Output quantity-of-interest selector: 0 = energy.
int output;

// Level-l estimator: runs Ns samples on level l and accumulates the
// level statistics into sums[] (defined after main).
void heat_l(int l, int Ns, double *sums);

//
// main code
//

/*
   Analytic reference value of the energy QoI at time T, evaluated from the
   eigen-series  sum_{m>=1} (1 - exp(-2 lam_m T)) / (2 lam_m),
   lam_m = (m pi)^2.  Terms m = 1..M-1 are summed directly; the neglected
   tail is replaced by the approximate correction 0.5*M/lam_{M-1}.
*/
static double heat_exact_energy(double T, int M) {
  const double pi = 3.14159265358979323846;
  double lam = 0.0, val = 0.0;

  for (int m=1; m<M; m++) {
    lam = pi*m;
    lam = lam*lam;
    val += 0.5*(1.0-exp(-2.0*lam*T))/lam;
  }
  if (M > 1) val += 0.5*M/lam;  // approximate truncation error

  return val;
}

/*
   Driver: for each coupling method (0 = FE, 1 = FV, 2 = spectral) run the
   MLMC convergence tests via mlmc_test, writing results to
   heat2_<coupling>.txt, and print the analytic reference value.
   An optional section (disabled by run_100_tests) repeats the MLMC
   calculation 100 times via mlmc_test_100, writing heat2_<coupling>_100.txt.
   Returns EXIT_FAILURE if an output file cannot be opened.
*/
int main(int argc, char **argv) {

  output = 0;

  int N0   = 100;  // initial samples on each level
  int Lmin = 2;    // minimum refinement level
  int Lmax = 10;   // maximum refinement level

  // accuracy targets for the MLMC runs; presumably the trailing 0.0
  // terminates the list for mlmc_test/mlmc_test_100 -- TODO confirm
  float Eps[] = { 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.0 };

  const int run_100_tests = 0;  // set non-zero to also run mlmc_test_100

  char  filename[32];
  FILE *fp;

#ifdef _OPENMP
  double wtime = omp_get_wtime();
  printf(" number of threads = %d \n",omp_get_max_threads());
#endif

  //
  // loop over the three MLMC coupling methods
  //

  for (coupling=0; coupling<3; coupling++) {

    int N = (coupling==1) ? 400000 : 100000;  // samples for convergence tests
    int L = 6;                                 // levels for convergence tests

    snprintf(filename, sizeof filename, "heat2_%d.txt", coupling);
    fp = fopen(filename,"w");
    if (fp == NULL) {
      perror(filename);
      return EXIT_FAILURE;
    }

    // initialise generator and FFTW storage, with separate storage for
    // each thread when compiled for OpenMP
#pragma omp parallel
    {
      rng_initialisation();
      fftw_initialisation_dirichlet(Lmax,4);
    }

    mlmc_test(heat_l, N,L, N0,Eps, Lmin,Lmax, fp);

    if (fclose(fp) != 0) perror(filename);

    // print out time taken, if using OpenMP
#ifdef _OPENMP
    printf(" execution time = %f s\n",omp_get_wtime() - wtime);
    wtime = omp_get_wtime();
#endif

    // delete generator and any associated storage
#pragma omp parallel
    {
      rng_termination();
      fftw_termination(Lmax);
    }

    //
    // print analytic reference value (T = 0.125 matches heat_l)
    //

    double val = heat_exact_energy(0.125, 10000);
    printf(" true value = %f \n",val);

    //
    // optionally do 100 MLMC calcs
    //

    if (run_100_tests) {

      snprintf(filename, sizeof filename, "heat2_%d_100.txt", coupling);
      fp = fopen(filename,"w");
      if (fp == NULL) {
        perror(filename);
        return EXIT_FAILURE;
      }

      // initialise generator and storage for each thread; FFTW storage is
      // needed as well, since heat_l uses it for the spectral coupling
#pragma omp parallel
      {
        rng_initialisation();
        fftw_initialisation_dirichlet(Lmax,4);
      }

      mlmc_test_100(heat_l, val, N0,Eps,Lmin,Lmax, fp);

      if (fclose(fp) != 0) perror(filename);

      // print out time taken, if using OpenMP
#ifdef _OPENMP
      printf(" execution time = %f s\n",omp_get_wtime() - wtime);
      wtime = omp_get_wtime();
#endif

      // delete generator and storage
#pragma omp parallel
      {
        rng_termination();
        fftw_termination(Lmax);
      }
    }
  }

  return 0;
}


/*-------------------------------------------------------
%
% level l estimator
%
*/

/*
   Level-l MLMC estimator for the stochastic heat equation with Richardson
   extrapolation.

   For each of Ns samples, three explicit finite-difference solutions are
   advanced with correlated noise, all with the same ratio dt/dx^2 = 0.25:
     fine         J   = 4<<l intervals,  N = 8<<(2l) timesteps
     coarse       J/2 intervals,         N/4 timesteps
     very coarse  J/4 intervals,         N/16 timesteps
   The global `coupling` selects how the fine noise is transferred to the
   coarser grids (0 = FE, 1 = FV, 2 = spectral).

   With discrete energies Pf, Pc, Pcc of the three solutions (Pc = 0 for
   l<1, Pcc = 0 for l<2), the Richardson-extrapolated value on level l is
   Pf2 = 2*Pf - Pc, and the MLMC correction is
   dP = Pf2 - (2*Pc - Pcc) = 2*(Pf-Pc) - (Pc-Pcc).

   On return (sums[] is overwritten):
     sums[0]     total cost, N*J per sample
     sums[1..4]  sums of dP, dP^2, dP^3, dP^4
     sums[5..6]  sums of Pf2, Pf2^2

   Levels outside 0..HEAT_LMAX (= 11) are rejected with an error message,
   because the per-sample work arrays have a fixed size.
*/
void heat_l(int l, int Ns, double *sums) {

  // largest supported level; J = 4<<l fine intervals need J+1 nodal values
  enum { HEAT_LMAX = 11, HEAT_JMAX = 4<<HEAT_LMAX };

  // guard against overflowing the fixed-size stack arrays below
  // (and against the undefined shift 4<<l for negative l)
  if (l < 0 || l > HEAT_LMAX) {
    fprintf(stderr, "heat_l: level %d outside supported range 0..%d\n",
            l, (int) HEAT_LMAX);
    exit(EXIT_FAILURE);
  }

  // variables declared here are shared by all OpenMP threads
  int J = 4<<l;      // number of fine grid intervals
  int N = 8<<(2*l);  // number of fine grid timesteps

  float lam = 0.25;  // dt/dx^2, fixed on all levels for stability

  float dx = 1.0/J;
  float dt = lam*dx*dx;

  // note: T = lam*N/J^2 = 0.25*8/16 = 0.125

  for (int k=0; k<7; k++) sums[k] = 0.0;

  /*
  OpenMP reduction of C++ array sections is discussed here:
  https://www.openmp.org/spec-html/5.0/openmpsu107.html
  and an example is given on page 301 here:
  https://www.openmp.org/wp-content/uploads/openmp-examples-5.0.0.pdf
  */

#pragma omp parallel for shared(Ns,J,N,lam,dx,dt) reduction(+:sums[0:7])
  for (int np=0; np<Ns; np++) {
    // variables declared here inside OpenMP parallel loop
    // will have local allocation for each thread
    float um, u0, up;
    float uf [HEAT_JMAX+1], uc [HEAT_JMAX/2+1], ucc [HEAT_JMAX/4+1];
    float dWf[HEAT_JMAX+1], dWc[HEAT_JMAX/2+1], dWcc[HEAT_JMAX/4+1];

    // zero initial condition on all three grids
    for (int j=0; j<=J;   j++) uf[j] = 0.0f;
    for (int j=0; j<=J/2; j++) uc[j] = 0.0f;
    for (int j=0; j<=J/4; j++) ucc[j] = 0.0f;

    // boundary noise values are never written again below
    dWf[0]  = 0.0f; dWf[J] = 0.0f;
    dWc[0]  = 0.0f; dWc[J/2] = 0.0f;
    dWcc[0] = 0.0f; dWcc[J/4] = 0.0f;

    // very coarse accumulators; reset after each very coarse update
    for (int j=0; j<=J/4; j++) dWcc[j] = 0.0f;

    for (int n=0; n<N/4; n++) {           // outer loop over coarse timesteps

      // coarse accumulators collect the 4 fine steps of this coarse step
      for (int j=0; j<=J/2; j++) dWc[j] = 0.0f;

      for (int m=0; m<4; m++) {             // inner loop over fine timesteps

        // create noise terms, and add contributions to coarse accumulators

        // FE coupling: element j adds con1*Z1+con2*Z2 to node j and con1*Z1
        // to node j+1, so interior nodes get covariance
        // (dt/dx)*tridiag(1/6, 2/3, 1/6); coarser grids are obtained by
        // 1/4, 1/2, 1/4 weighting of the next finer grid's increments
        if (coupling==0) {
          float con1 = sqrtf(dt/(6.0f*dx));
          float con2 = sqrtf(dt/(3.0f*dx));

          for (int j=0; j<J; j++) {
            float Z1 = next_normal();
            float Z2 = next_normal();

            if (j>0)   dWf[j]  += con1*Z1 + con2*Z2;
            if (j<J-1) dWf[j+1] = con1*Z1;
          }

          for (int j=1; j<J/2; j++)
            dWc[j] += 0.25f*dWf[2*j-1] + 0.5f*dWf[2*j] + 0.25f*dWf[2*j+1];

          // the coarse increment is complete after the 4th fine step
          if (m == 3) {
            for (int j=1; j<J/4; j++)
              dWcc[j] += 0.25f*dWc[2*j-1] + 0.5f*dWc[2*j] + 0.25f*dWc[2*j+1];
          }
        }

        // FV coupling: in fine interval j, Z1 feeds node j and Z2 feeds
        // node j+1; the same (Z1+Z2) is added, scaled by 1/2 and 1/4, to
        // coarse node (j+1)/2 and very coarse node (j+2)/4 (boundary
        // nodes excluded)
        else if (coupling==1) {
          float con1 = sqrtf(0.5f*dt/dx);

          for (int j=0; j<J; j++) {
            float Z1 = next_normal();
            float Z2 = next_normal();

            if (j>0)   dWf[j]  += con1*Z1;
            if (j<J-1) dWf[j+1] = con1*Z2;

            if (j>0 && j<J-1) dWc[(j+1)/2] += 0.5f*con1*(Z1+Z2);
            if (j>1 && j<J-2) dWcc[(j+2)/4] += 0.25f*con1*(Z1+Z2);
          }
        }

        // spectral coupling
        else if (coupling==2) {
          // NOTE: fftw_calc multiplies sine summation by factor 2
          // https://www.fftw.org/fftw3_doc/1d-Real_002dodd-DFTs-_0028DSTs_0029.html
          // this is why con1 = sqrtf(0.5f*dt) rather than sqrtf(2.0f*dt)
          float con1 = sqrtf(0.5f*dt);
          float con2 = sqrtf(0.5f*dt);

          for (int j=0; j<J-1; j++) fftw_in[j] = next_normal();

          fftw_calc(l);
          for (int j=1; j<J; j++) dWf[j] = con1*fftw_out[j-1];

          if (l>0) {
            fftw_calc(l-1);
            for (int j=1; j<J/2; j++) dWc[j] += con2*fftw_out[j-1];
          }

          if (l>1) {
            fftw_calc(l-2);
            for (int j=1; j<J/4; j++) dWcc[j] += con2*fftw_out[j-1];
          }
        }

        // fine grid update: explicit scheme
        //   u_j <- u_j + lam*(u_{j+1} - 2u_j + u_{j-1}) + dW_j,
        // with um/u0/up holding pre-update values because uf[j-1] has
        // already been overwritten; uf[0] and uf[J] stay zero
        u0 = uf[0]; up = uf[1];
        for (int j=1; j<J; j++) {
          um = u0; u0 = up; up = uf[j+1];
          uf[j] = u0 + lam*((up-u0)-(u0-um)) + dWf[j];
        }
      }

      // coarse grid update (once per 4 fine steps)
      u0 = uc[0]; up = uc[1];
      for (int j=1; j<J/2; j++) {
        um = u0; u0 = up; up = uc[j+1];
        uc[j] = u0 + lam*((up-u0)-(u0-um)) + dWc[j];
      }

      // very coarse grid update (once per 4 coarse steps), then reset
      // its noise accumulator
      if (l>1 && (n%4 == 3)) {
        u0 = ucc[0]; up = ucc[1];
        for (int j=1; j<J/4; j++) {
          um = u0; u0 = up; up = ucc[j+1];
          ucc[j] = u0 + lam*((up-u0)-(u0-um)) + dWcc[j];
          dWcc[j] = 0.0f;
        }
      }
    }

    // compute discrete energies sum(dx_grid * u^2) on each grid

    float Pf = 0.0f;
    for (int j=1; j<J; j++) Pf += dx*uf[j]*uf[j];

    float Pc = 0.0f;
    for (int j=1; j<J/2; j++) Pc += 2.0f*dx*uc[j]*uc[j];

    float Pcc = 0.0f;
    for (int j=1; j<J/4; j++) Pcc += 4.0f*dx*ucc[j]*ucc[j];

    // coarser levels do not exist on the lowest levels
    if (l<1) Pc  = 0.0f;
    if (l<2) Pcc = 0.0f;

    // Richardson-extrapolated MLMC correction and level value
    float dP = 2.0f*(Pf-Pc) - (Pc-Pcc);
    float Pf2 = 2.0f*Pf - Pc;

    sums[0] += ((float) N) * ((float) J);   // add timesteps*grid_points as cost
    sums[1] += dP;
    sums[2] += dP*dP;
    sums[3] += dP*dP*dP;
    sums[4] += dP*dP*dP*dP;
    sums[5] += Pf2;
    sums[6] += Pf2*Pf2;
  }

}

