/*
 * Example SDE code, implementing the Euler-Maruyama scheme for 
 * Geometric Brownian Motion using MKL/VSL random number generation
 * and OpenMP vectorisation and parallelisation
 *
 * author: Mike Giles, based on previous code written by 
 *         David J. Warne
 *         School of Mathematical Sciences 
 *         Queensland University of Technology
 *
 * Date: Nov 23 2016
 */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <mkl.h>
#include <mkl_vsl.h>
#include <memory.h>
#include <omp.h>

// vector length for vectorisation: per-thread path-range boundaries (main)
// and per-group path counts (pathcalc) are rounded to multiples of this.
// NOTE(review): the `omp simd` loop in pathcalc requests simdlen(32),
// which is independent of this value.
#define VECTOR_LENGTH 4   

// macros for rounding x up or down to a multiple of y.
// Integer arithmetic only; correct for non-negative x and positive y
// (C integer division truncates toward zero, so negative x is not handled).
// Arguments are evaluated more than once - pass side-effect-free expressions.
#define ROUND_UP(x, y) ( ( ((x) + (y) - 1) / (y) ) * (y) )
#define ROUND_DOWN(x, y) ( ((x) / (y)) * (y) )

// each OpenMP thread has its own VSL RNG and storage.
// RV_BYTES is the size in bytes of each thread's dW buffer (allocated in
// main); pathcalc chooses its path-group size so that M*N2 doubles fit in it.
#define RV_BYTES 65536
// per-thread buffer of Gaussian increments, filled by vdRngGaussian in pathcalc
double           *dW;
// per-thread MKL VSL stream, created in main with MRG32k3a and a
// thread-dependent skip-ahead so the threads draw non-overlapping sequences
VSLStreamStatePtr stream;
// give every thread its own copy of stream and dW.
// NOTE(review): OpenMP only guarantees that threadprivate values persist
// between parallel regions under certain conditions (e.g. unchanged team
// size, dynamic thread adjustment disabled) - main relies on this.
#pragma omp threadprivate(stream, dW)

// pathcalc(T, X0, mu, sigma, dt, M, N, sum1_t, sum2_t): simulates N paths of
// M timesteps on the calling thread and writes the sum of the terminal values
// to *sum1_t and the sum of their squares to *sum2_t.
void pathcalc(double,double,double,double,double, int,int,
              double *,double *);

int main(int argc, char **argv)
{
  double T=1.0, X0=1.0, mu=0.05, sigma=0.2, sum1=0.0, sum2=0.0, dt;
  int    M  = 20;       /* number of timesteps */
  int    N  = 51200000; /* total number of MC samples */

  dt  = T / ((double) M);
  
  printf("#OpenMP threads = %d\n\n",omp_get_max_threads());
  
// parallel initialisation for each thread
#pragma omp parallel
  {
    // create persistent RNG, with a unique skipahead for each thread
    vslNewStream(&stream, VSL_BRNG_MRG32K3A,1337);
    int       tid    = omp_get_thread_num();
    long long skip = ((long long) (tid+1)) << 48;
    vslSkipAheadStream(stream,skip);

    // allocate memory
    dW = (double *)malloc(RV_BYTES);
  }
  
// now do path calculations in parallel
  double start = omp_get_wtime();

#pragma omp parallel shared(T,X0,mu,sigma,dt,M,N)	\
                     reduction(+:sum1,sum2)
  {
    double sum1_t = 0.0, sum2_t = 0.0;
    int    num_t  = omp_get_num_threads();
    int    tid    = omp_get_thread_num();

    int N3 = ROUND_UP(((tid+1)*((long long) N))/num_t,VECTOR_LENGTH)
           - ROUND_UP(( tid   *((long long) N))/num_t,VECTOR_LENGTH);

    pathcalc(T,X0,mu,sigma,dt, M, N3, &sum1_t, &sum2_t);
    sum1 += sum1_t;
    sum2 += sum2_t;
  }
  
  double elapsed = omp_get_wtime()-start;
  printf("Total elapsed time = %g secs\n",elapsed);
  double normals = M*N;
  printf("Total Normals generated = %g at %g/sec\n\n",normals,normals/elapsed);
  
  printf("Exact solution E[X_T] = %g\n",X0*exp(mu*T));
  printf("Monte Carlo estimate  = %g +/- %g \n",sum1/N,
         3.0*sqrt((sum2/N-(sum1/N)*(sum1/N))/N));
  printf("Reminder: Monte Carlo estimate has discretisation bias\n");

// delete generator and storage for each thread
#pragma omp parallel 
  {
    vslDeleteStream(&stream);
    free(dW);
  }
}

/*
 * Simulate N Euler-Maruyama paths of geometric Brownian motion
 *     X_{k+1} = X_k * (1 + mu*dt + sigma*dW_k),   X_0 = X0,
 * over M timesteps, with dW_k ~ N(0, dt), and accumulate terminal values.
 *
 * Increments are drawn from this thread's threadprivate VSL `stream` into
 * this thread's `dW` buffer (RV_BYTES bytes), so paths are processed in
 * groups small enough that M*N2 doubles fit in the buffer.
 *
 * T       - final time; not needed here (the horizon is implied by M*dt)
 * X0      - initial value of every path
 * mu      - drift coefficient
 * sigma   - volatility coefficient
 * dt      - timestep size
 * M       - number of timesteps per path (must be >= 1)
 * N       - number of paths to simulate on the calling thread
 * sum1_t  - output: sum of X_T over the N paths
 * sum2_t  - output: sum of X_T^2 over the N paths
 *
 * Thread 0 also prints RNG and path-computation throughput. Exits the
 * program with an error message if M is not positive or a single path of
 * M increments does not fit in the dW buffer.
 */
void pathcalc(double T,double X0,double mu,double sigma,double dt,
              int M, int N, double *sum1_t, double *sum2_t) {

  (void) T;

  double sum1=0.0, sum2=0.0;
  double start=omp_get_wtime(), stop, gen_time=0.0, path_time=0.0;
  double normals=0.0, paths=0.0;

  /* the group-size computation below divides by M */
  if (M < 1) {
    fprintf(stderr, "pathcalc: number of timesteps M=%d must be positive\n", M);
    exit(EXIT_FAILURE);
  }

  /* work out max number of paths in a group: as many as fit in dW
     (floor(floor(RV_BYTES/8)/M) == floor(RV_BYTES/(8*M)), without the
     risk of overflowing M*sizeof(double) in an int), rounded down to a
     multiple of VECTOR_LENGTH for the SIMD loop */
  int max_paths = (int) (RV_BYTES / sizeof(double)) / M;
  int N2 = ROUND_DOWN(max_paths, VECTOR_LENGTH);

  /* fewer than VECTOR_LENGTH paths fit: still correct, just not a
     multiple of the vector length (a zero group size would loop forever) */
  if (N2 == 0)
    N2 = max_paths;

  if (N2 == 0 && N > 0) {
    fprintf(stderr, "pathcalc: M=%d timesteps do not fit in the %d-byte RNG buffer\n",
            M, RV_BYTES);
    exit(EXIT_FAILURE);
  }

  /* loop invariants of the Euler-Maruyama update */
  const double sqrt_dt = sqrt(dt);
  const double drift   = 1.0 + mu*dt;

  /* loop over all paths in groups of size N2 */
  for (int n0=0; n0<N; n0+=N2) {
    /* may have to reduce size of final group */
    if (N2>N-n0) N2 = N-n0;

    /* generate required random numbers for this group, stored so that
       timestep m of path n2 is at dW[m*N2+n2] */
    vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2,
                  stream,M*N2,dW,0,sqrt_dt);

    normals += M*N2;
    stop = omp_get_wtime();
    gen_time += stop-start;
    start = stop;

#pragma omp simd reduction(+:sum1,sum2) simdlen(32)
    for (int n2=0; n2<N2; n2++) {
      double X = X0;

      for (int m=0; m<M; m++) {
        double delW  = dW[m*N2+n2];
        X = X*(drift + sigma*delW);
      }

      sum1 += X;
      sum2 += X*X;
    }

    paths += N2;
    stop = omp_get_wtime();
    path_time += stop-start;
    start = stop;
  }

  *sum1_t = sum1;
  *sum2_t = sum2;

  if(omp_get_thread_num()==0) {
    printf("Thread 0 generated %g Normals in %g sec, at %g/sec\n",
	   normals,gen_time,normals/gen_time);
    /* path loop reads 8 bytes (one double) per normal */
    printf("Thread 0 calculated %g paths in %g sec, at %g GB/s\n\n",
	   paths,path_time,8.0*normals/path_time/1e9);
  }
}
