//
//----- C++11 random number generation when not using OpenMP -------------
//

#ifndef _OPENMP

#include <random>           // C++11 random number generators
#include <functional>

/* some web references

   https://www.cplusplus.com/reference/random/
   https://stackoverflow.com/questions/14023880/c11-random-numbers-and-stdbind-interact-in-unexpected-way/14023935
   https://stackoverflow.com/questions/20671573/c11-stdgenerate-and-stduniform-real-distribution-called-two-times-gives-st

*/

// Global generator and output distributions for the serial (non-OpenMP)
// build. One default_random_engine feeds both distributions, so uniform and
// normal draws consume a single shared sequence.

std::default_random_engine rng;
std::uniform_real_distribution<float> uniform(0.0f,1.0f);
std::normal_distribution<float> normal(0.0f,1.0f);

// Return the next uniform variate on [0,1) / the next standard normal variate.
// Plain inline functions (instead of std::bind objects) keep the same call
// syntax, match the interface of the MKL/VSL versions used in the OpenMP
// build, and use the global engine/distributions directly (no copies).
inline float next_uniform() { return uniform(rng); }
inline float next_normal()  { return normal(rng); }

// Serial build: discard any internal state the two distributions may hold,
// then restart the shared engine from a fixed seed so runs are reproducible.
void rng_initialisation() {
    normal.reset();
    uniform.reset();
    rng.seed(1234);
}

// Serial build: nothing to release, because the engine and distributions are
// plain global objects. Kept so both builds expose the same init/termination
// interface.
void rng_termination() {
    return;
}

//------- MKL/VSL random number generation when using OpenMP -----------

#else

#include <memory.h>
#include <stdio.h>
#include <stdlib.h>

#include <mkl.h>
#include <mkl_vsl.h>
#include <omp.h>

/* Per-thread generator state: every OpenMP thread owns a private VSL stream
   plus two buffers of pre-generated variates and their remaining counts. */

#define NRV 16384  // buffer length: random variables generated per refill

VSLStreamStatePtr stream;          // this thread's VSL stream
float            *uniforms;        // buffered uniform variates
float            *normals;         // buffered normal variates
int               uniforms_count;  // unread entries left in uniforms[]
int               normals_count;   // unread entries left in normals[]
#pragma omp threadprivate(stream, uniforms, uniforms_count, normals, normals_count)

//
// RNG routines
//

void rng_initialisation(){
  int tid = omp_get_thread_num();
  vslNewStream(&stream, VSL_BRNG_MRG32K3A,1337);
  long long skip = ((long long) (tid+1)) << 48;
  vslSkipAheadStream(stream,skip);
  uniforms     = (float *) malloc(NRV*sizeof(float));
  normals      = (float *) malloc(NRV*sizeof(float));
  uniforms_count     = 0;  // this means there are no random
  normals_count      = 0;  // numbers in the arrays currently
}

// Per-thread clean-up: release this thread's variate buffers, then its VSL
// stream. Called by every thread of a parallel region, mirroring
// rng_initialisation().
void rng_termination(){
  free(normals);
  free(uniforms);
  vslDeleteStream(&stream);
}

// Return the next uniform variate from this thread's buffer. When the buffer
// is empty, regenerate all NRV entries with vsRngUniform (parameters 0 and 1),
// then hand them out from the end of the array backwards.
inline float next_uniform(){
  if (uniforms_count != 0)
    return uniforms[--uniforms_count];   // fast path: a buffered value remains

  vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, NRV, uniforms, 0.0f, 1.0f);
  uniforms_count = NRV - 1;
  return uniforms[uniforms_count];
}

// Return the next normal variate from this thread's buffer. When the buffer
// is empty, regenerate all NRV entries with vsRngGaussian (Box-Muller method
// BOXMULLER2, mean 0, standard deviation 1), then hand them out from the end
// of the array backwards.
inline float next_normal(){
  if (normals_count != 0)
    return normals[--normals_count];     // fast path: a buffered value remains

  vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, stream, NRV, normals,
                0.0f, 1.0f);
  normals_count = NRV - 1;
  return normals[normals_count];
}

#endif

//
// other header files needed for both versions
//

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

//
// main code
//

//
// Monte Carlo estimate of E[X_T] for dX = mu X dt + sigma X dW, using M
// Euler timesteps per path and N paths. Paths are split into contiguous
// chunks, one per thread in the OpenMP build. Prints the estimate with a
// 3-standard-error bound and, under OpenMP, the simulation time.
//
int main(int argc, char **argv)
{
  (void) argc;   // command-line arguments are not used
  (void) argv;

  float  T=1.0f, X0=1.0f, mu=0.05f, sigma=0.2f;
  double sums[2]={0.0,0.0};   // running sums of X_T and X_T^2
  int    M = 200;             // number of timesteps
  int    N = 9600000;         // total number of MC samples

  float  dt = T / static_cast<float>(M);

#ifdef _OPENMP
  // The threadprivate RNG state created in the first parallel region is
  // only guaranteed to persist into later regions if the runtime cannot
  // change the team size between them.
  omp_set_dynamic(0);
#endif

// initialise generator, with separate storage for each
// thread when compiled for OpenMP
#pragma omp parallel
  rng_initialisation();

#ifdef _OPENMP
  double wtime = omp_get_wtime();
#endif

#pragma omp parallel default(none) shared(X0,mu,sigma,dt,M,N) \
                                   reduction(+:sums[:2])
  {
#ifdef _OPENMP
    int tid = omp_get_thread_num();
    int nt  = omp_get_num_threads();   // actual team size of this region
#else
    int tid = 0;
    int nt  = 1;
#endif

    // Contiguous chunk [n_begin, n_end) of the N samples for this thread.
    // 64-bit products avoid int overflow of N*tid when there are many threads.
    long long n_begin = (static_cast<long long>(N) * tid)       / nt;
    long long n_end   = (static_cast<long long>(N) * (tid + 1)) / nt;

    // loop-invariant drift and diffusion factors of the Euler step
    float drift     = mu*dt;
    float diffusion = sigma*sqrtf(dt);

    for (long long n=n_begin; n<n_end; n++) {
      float X = X0;

      for (int m=0; m<M; m++) {
        float Z = next_normal();
        X = X + X*(drift + diffusion*Z);
      }

      sums[0] += X;
      sums[1] += X*X;
    }
  }

#ifdef _OPENMP
  // stop the clock before any output so only the simulation is timed
  wtime = omp_get_wtime() - wtime;
#endif

  double mean     = sums[0]/N;
  double variance = sums[1]/N - mean*mean;
  if (variance < 0.0) variance = 0.0;   // guard against round-off

  printf("Exact solution E[X_T] = %g\n",X0*exp(mu*T));
  printf("Monte Carlo estimate  = %g +/- %g \n",mean,
         3.0*sqrt(variance/N));
  printf("\nReminder: Monte Carlo estimate has discretisation bias\n\n");
  double RNGs = static_cast<double>(N) * static_cast<double>(M);
  printf("Random Nums generated = %10.4g\n",RNGs);

#ifdef _OPENMP
  printf("threads               = %d\n",omp_get_max_threads());
  printf("execution time (s)    = %10.4g\n",wtime);
  printf("RNG/s                 = %10.4g\n\n",RNGs/wtime);
#endif

// delete generator and storage
#pragma omp parallel
  rng_termination();

  return 0;
}
