/*
   MLMC tests for the stochastic heat equation, comparing different
   fine/coarse level couplings (FE, FV, spectral) and different output
   quantities of interest (energy, linear functional)
*/

//#include "mlmc_test.cpp"   // master MLMC file
#include "mlmc_test_100.cpp" // master file for 100 tests

#include "mlmc_rng.cpp"      // file with RNG functions
#include "mlmc_fftw.cpp"     // file with FFTW functions

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#ifdef _OPENMP
#include <omp.h>
#endif

//
// some declarations
//

// Fine/coarse coupling used by the level estimator:
//   0 = finite element (FE), 1 = finite volume (FV), 2 = spectral
int coupling = 0;

// Output quantity of interest:
//   0 = energy, 1 = linear functional variance
int output = 0;

// Level l estimator (defined after main)
void heat_l(int l, int Ns, double *sums);

//
// main code
//

/*
  Analytic reference ("true") value of the output QoI at T = 0.125, from
  the sine-series solution truncated to modes m = 1..M-1, with
    lam_m  = (m*pi)^2
    sig2_m = (1 - exp(-2*lam_m*T)) / (2*lam_m)
  qoi 0 (energy):            sum_m sig2_m, plus an approximate tail term
  qoi 1 (linear functional): sum over odd m of 96*sig2_m/lam_m^2
*/
static double heat_exact_value(int qoi) {
  const double T = 0.125;
  const int    M = 10000;   // number of series terms (truncation point)

  double lam = 0.0, val = 0.0;

  for (int m=1; m<M; m++) {
    lam = 3.14159265358979323846*m;
    lam = lam*lam;
    double sig2 = 0.5*(1.0-exp(-2.0*lam*T))/lam;
    if (qoi==0)
      val += sig2;
    else if (m%2==1)
      val += sig2*96.0/(lam*lam);
  }

  if (qoi==0)
    val += 0.5*M/lam;  // approximate truncation error of the omitted tail

  return val;
}

/*
  Runs MLMC convergence tests for every (output, coupling) pair, writing
  results to heat_<output>_<coupling>.txt and printing the analytic value.
  Returns EXIT_FAILURE if an output file cannot be opened.
*/
int main(int argc, char **argv) {
  (void) argc;   // command-line arguments are not used
  (void) argv;

  const int N0   = 50;   // initial samples on each level
  const int Lmin = 2;    // minimum refinement level
  const int Lmax = 10;   // maximum refinement level

  // set non-zero to also run 100 independent MLMC calcs per case
  const int run_100_tests = 0;

  // target accuracies for each QoI; the trailing 0.0 presumably marks the
  // end of the list for mlmc_test -- TODO confirm
  static const float eps_energy[] = { 0.0005, 0.001, 0.002, 0.005, 0.01, 0.0 };
  static const float eps_linear[] = { 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.0 };

  float Eps[11];
  char  filename[32];
  FILE *fp;

#ifdef _OPENMP
  double wtime = omp_get_wtime();
  printf(" number of threads = %d \n",omp_get_max_threads());
#endif

//
// loop over different outputs and different couplings
//

  for (output=0; output<2; output++) {
    for (coupling=0; coupling<3; coupling++) {

      int N = (coupling==1) ? 400000 : 100000;  // samples for convergence tests
      int L = 6;                                // levels for convergence tests

      if (output==0) memcpy(Eps, eps_energy, sizeof(eps_energy));
      else           memcpy(Eps, eps_linear, sizeof(eps_linear));

      // open the output file before initialising the generator, so a
      // failure here leaves no RNG/FFTW state to tear down
      snprintf(filename, sizeof(filename), "heat_%d_%d.txt", output, coupling);
      fp = fopen(filename, "w");
      if (fp == NULL) {
        perror(filename);
        return EXIT_FAILURE;
      }

      // initialise generator, with separate storage for each
      // thread when compiled for OpenMP
#pragma omp parallel
      {
        rng_initialisation();
        fftw_initialisation_dirichlet(Lmax,4);
      }

      mlmc_test(heat_l, N,L, N0,Eps, Lmin,Lmax, fp);

      if (fclose(fp) != 0) perror(filename);

      // print out time taken, if using OpenMP
#ifdef _OPENMP
      printf(" execution time = %f s\n",omp_get_wtime() - wtime);
      wtime = omp_get_wtime();
#endif

      // delete generator and any associated storage
#pragma omp parallel
      {
        rng_termination();
        fftw_termination(Lmax);
      }

      // print exact analytic value
      double val = heat_exact_value(output);
      printf(" true value = %f \n",val);

      //
      // now do 100 MLMC calcs
      //

      if (run_100_tests) {
        snprintf(filename, sizeof(filename), "heat_%d_%d_100.txt", output, coupling);
        fp = fopen(filename, "w");
        if (fp == NULL) {
          perror(filename);
          return EXIT_FAILURE;
        }

        // heat_l uses the FFTW storage (spectral coupling) as well as the
        // RNG, and both were terminated above, so re-initialise both
#pragma omp parallel
        {
          rng_initialisation();
          fftw_initialisation_dirichlet(Lmax,4);
        }

        mlmc_test_100(heat_l, val, N0,Eps,Lmin,Lmax, fp);

        if (fclose(fp) != 0) perror(filename);

        // print out time taken, if using OpenMP
#ifdef _OPENMP
        printf(" execution time = %f s\n",omp_get_wtime() - wtime);
        wtime = omp_get_wtime();
#endif

        // delete generator and storage
#pragma omp parallel
        {
          rng_termination();
          fftw_termination(Lmax);
        }
      }
    }
  }

  return 0;
}


/*-------------------------------------------------------
%
% level l estimator
%
*/

/*
  Level l estimator for the stochastic heat equation on [0,1] with u = 0
  at both ends, using the explicit scheme
    u_j <- u_j + lam*(u_{j+1} - 2*u_j + u_{j-1}) + dW_j,   lam = dt/dx^2
  up to T = 0.125. Each sample advances a fine grid (J = 4*2^l intervals,
  N = 8*4^l timesteps) and a coarse grid (J/2 intervals, N/4 timesteps)
  driven by coupled noise chosen by the global `coupling`. The global
  `output` selects the QoI P evaluated at the final time.

  l    : level; must satisfy 0 <= l <= 11 (per-sample array capacity),
         otherwise an error is printed and the program exits
  Ns   : number of samples
  sums : 7 accumulators, overwritten:
         [0]    cost, sum over samples of N*J (timesteps * grid points)
         [1..4] sums of dP, dP^2, dP^3, dP^4 with dP = Pf - Pc
         [5..6] sums of Pf, Pf^2
  On level 0 the coarse value is discarded (Pc = 0).
*/
void heat_l(int l, int Ns, double *sums) {

  // per-sample grid arrays live on each thread's stack with a fixed
  // capacity, so the finest grid J = 4<<l must not exceed JMAX
  enum { LMAX_ARRAYS = 11, JMAX = 4 << LMAX_ARRAYS };

  if (l < 0 || l > LMAX_ARRAYS) {
    fprintf(stderr, "heat_l: level %d outside supported range 0..%d\n",
            l, (int) LMAX_ARRAYS);
    exit(EXIT_FAILURE);
  }

  // variables declared here are shared by all OpenMP threads
  int J = 4<<l;      // number of fine grid intervals
  int N = 8<<(2*l);  // number of fine grid timesteps

  float lam = 0.25;  // dt/dx^2, fixed on all levels for stability

  float dx = 1.0/J;
  float dt = lam*dx*dx;

  // note: T = lam*N/J^2 = 0.25*8/16 = 0.125

  // noise scalings; loop-invariant, so computed once per call
  float fe_con1 = sqrtf(dt/(6.0f*dx));  // FE: Z1 of interval j, added to nodes j and j+1
  float fe_con2 = sqrtf(dt/(3.0f*dx));  // FE: Z2 of interval j, added to node j only
  float fv_con  = sqrtf(0.5f*dt/dx);    // FV: one normal per half-interval
  // spectral: fftw_calc multiplies the sine summation by factor 2
  // https://www.fftw.org/fftw3_doc/1d-Real_002dodd-DFTs-_0028DSTs_0029.html
  // which is why this is sqrtf(0.5f*dt) rather than sqrtf(2.0f*dt)
  float sp_con  = sqrtf(0.5f*dt);

  for (int k=0; k<7; k++) sums[k] = 0.0;

  /*
  OpenMP reduction of C++ array sections is discussed here:
  https://www.openmp.org/spec-html/5.0/openmpsu107.html
  and an example is given on page 301 here:
  https://www.openmp.org/wp-content/uploads/openmp-examples-5.0.0.pdf
  */

#pragma omp parallel for shared(Ns,J,N,lam,dx,dt,fe_con1,fe_con2,fv_con,sp_con) reduction(+:sums[0:7])
  for (int np=0; np<Ns; np++) {
    // variables declared here inside OpenMP parallel loop
    // will have local allocation for each thread
    float um, u0, up;
    float uf[JMAX+1], uc[JMAX/2+1], dWf[JMAX+1], dWc[JMAX/2+1];

    for (int j=0; j<=J;   j++) uf[j] = 0.0f;
    for (int j=0; j<=J/2; j++) uc[j] = 0.0f;

    // boundary noise stays zero (Dirichlet ends are never updated)
    dWf[0] = 0.0f; dWf[J] = 0.0f;
    dWc[0] = 0.0f; dWc[J/2] = 0.0f;

    for (int n=0; n<N/4; n++) {                   // outer loop over coarse timesteps

      for (int j=0; j<=J/2; j++) dWc[j] = 0.0f;   // zero out accumulators

      for (int m=0; m<4; m++) {                   // inner loop over fine timesteps

        // create noise terms, and add contributions to coarse accumulators

        // FE coupling
        if (coupling==0) {
          for (int j=0; j<J; j++) {
            float Z1 = next_normal();
            float Z2 = next_normal();

            // dWf[j] was assigned (=) at iteration j-1, then completed here
            if (j>0)   dWf[j]  += fe_con1*Z1 + fe_con2*Z2;
            if (j<J-1) dWf[j+1] = fe_con1*Z1;
          }

          // coarse noise: (1/4, 1/2, 1/4) restriction of the fine noise
          for (int j=1; j<J/2; j++)
            dWc[j] += 0.25f*dWf[2*j-1] + 0.5f*dWf[2*j] + 0.25f*dWf[2*j+1];
        }

        // FV coupling
        else if (coupling==1) {
          for (int j=0; j<J; j++) {
            float Z1 = next_normal();
            float Z2 = next_normal();

            if (j>0)   dWf[j]  += fv_con*Z1;
            if (j<J-1) dWf[j+1] = fv_con*Z2;

            // fine intervals 2k-1 and 2k both feed coarse node k
            if (j>0 && j<J-1) dWc[(j+1)/2] += 0.5f*fv_con*(Z1+Z2);
          }
        }

        // spectral coupling: coarse grid reuses the lowest J/2-1 modes
        else if (coupling==2) {
          for (int j=0; j<J-1; j++) fftw_in[j] = next_normal();

          fftw_calc(l);
          for (int j=1; j<J; j++) dWf[j] = sp_con*fftw_out[j-1];

          if (l>0) {
            fftw_calc(l-1);
            for (int j=1; j<J/2; j++) dWc[j] += sp_con*fftw_out[j-1];
          }
        }

        // fine grid update
        u0 = uf[0]; up = uf[1];
        for (int j=1; j<J; j++) {
          um = u0; u0 = up; up = uf[j+1];
          uf[j] = u0 + lam*((up-u0)-(u0-um)) + dWf[j];
        }
      }

      // coarse grid update (same lam, since coarse dt/dx^2 = 4*dt/(4*dx^2))
      u0 = uc[0]; up = uc[1];
      for (int j=1; j<J/2; j++) {
        um = u0; u0 = up; up = uc[j+1];
        uc[j] = u0 + lam*((up-u0)-(u0-um)) + dWc[j];
      }
    }

    // evaluate the output QoI on the fine (Pf) and coarse (Pc) grids

    float Pc=0.0f, Pf=0.0f;

    if (output==0) {
      // energy: discrete L2 norm squared; coarse spacing is 2*dx
      for (int j=1; j<J; j++)   Pf += dx*uf[j]*uf[j];
      for (int j=1; j<J/2; j++) Pc += 2.0f*dx*uc[j]*uc[j];
    }
    else if (output==1) {
      // square of linear functional with min(x,1-x);
      // factor 12 to make ||phi||=1
      for (int j=1; j<J; j++)   Pf += dx*uf[j]*fminf(j*dx,1.0f-j*dx);
      Pf = 12.0f*Pf*Pf;
      for (int j=1; j<J/2; j++) Pc += 2.0f*dx*uc[j]*fminf(2*j*dx,1.0f-2*j*dx);
      Pc = 12.0f*Pc*Pc;
    }

    if (l==0) Pc = 0.0f;

    float dP = Pf - Pc;

    sums[0] += ((float) N) * ((float) J);   // add timesteps*grid_points as cost
    sums[1] += dP;
    sums[2] += dP*dP;
    sums[3] += dP*dP*dP;
    sums[4] += dP*dP*dP*dP;
    sums[5] += Pf;
    sums[6] += Pf*Pf;
  }

}

