////////////////////////////////////////////////////////////////////////
// Performance assessment of Intel's MKL/VSL library compared to
// my generation of 16-bit Normals based on 16-bit random integers
////////////////////////////////////////////////////////////////////////

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#include <mkl.h>
#include <mkl_vsl.h>
#include <memory.h>
#include <immintrin.h>
#include <mathimf.h>
#include <omp.h>

#include "my_mm512.h"

//
// Normal pdf function, not defined in mathimf.h
//

double pdfnorm(double x) {
  return exp(-0.5*x*x)/sqrt(2.0*M_PI);
}

/* Per-thread RNG state: every OpenMP thread owns its own VSL stream and its
   own sample buffers (made thread-local by the threadprivate directive below) */

#define NRV 1024  // number of random AVX512 vectors held in each buffer
//#define NRV 256  // smaller alternative buffer size

VSLStreamStatePtr stream;               // this thread's VSL random stream

__m512i *uniformbits;                   // buffer of random integer vectors
int      uniformbits_count;             // vectors still unused in uniformbits

__m512  *normals;                       // buffer of fp32 Normal vectors
int      normals_count;                 // vectors still unused in normals

__m512h *normals_fp16;                  // buffer of fp16 Normal vectors
int      normals_fp16_count;            // vectors still unused in normals_fp16

#pragma omp threadprivate(stream, uniformbits, uniformbits_count, \
                          normals, normals_count, \
                          normals_fp16, normals_fp16_count)

//
// RNG routines
//

//
// Set up this thread's RNG: create an MT19937 stream seeded with 1337,
// skip it ahead by (tid+1)*2^48 so each thread draws from a disjoint
// sub-sequence, and allocate the thread's three sample buffers.
// Intended to be called once from inside an OpenMP parallel region.
//
void rng_initialisation(void){
  int tid = omp_get_thread_num();
  vslNewStream(&stream, VSL_BRNG_MT19937,1337);
  long long skip = ((long long) (tid+1)) << 48;
  vslSkipAheadStream(stream,skip);

  // AVX-512 vector types require 64-byte alignment and the compiler may use
  // aligned loads/stores on them; plain malloc only guarantees 16 bytes.
  // NRV*64 is a multiple of 64, as aligned_alloc requires.  Release with free().
  uniformbits  = aligned_alloc(64, NRV*sizeof *uniformbits);
  normals      = aligned_alloc(64, NRV*sizeof *normals);
  normals_fp16 = aligned_alloc(64, NRV*sizeof *normals_fp16);
  if (uniformbits == NULL || normals == NULL || normals_fp16 == NULL) {
    fprintf(stderr, "rng_initialisation: out of memory on thread %d\n", tid);
    abort();   // abort (not exit) is safe to call from inside a parallel region
  }

  uniformbits_count  = 0;  // this means there are no random
  normals_count      = 0;  // numbers in the arrays currently
  normals_fp16_count = 0;
}

void rng_termination(){
  vslDeleteStream(&stream);
  free(uniformbits);
  free(normals);
}

//
// Return the next vector of random 32-bit integers for this thread,
// refilling the thread's buffer (NRV vectors) from its VSL stream when empty.
//
// "static inline": a plain C99 "inline" definition provides no external
// definition, so any call the compiler chooses not to inline (e.g. at -O0)
// would fail to link.
//
static inline __m512i next_uniformbits(void){
  if (uniformbits_count==0) {
    // 16 32-bit integers per 512-bit vector
    viRngUniformBits32(VSL_RNG_METHOD_UNIFORMBITS32_STD,
                       stream,16*NRV,(unsigned int *)uniformbits);
    uniformbits_count = NRV;
  }
  return uniformbits[--uniformbits_count];
}

//
// Return the next vector of 16 fp32 standard Normals for this thread,
// refilling the thread's buffer (NRV vectors) via VSL's inverse-CDF
// Gaussian generator (mean 0, sigma 1) when empty.
//
// "static inline": a plain C99 "inline" definition provides no external
// definition, so any call the compiler chooses not to inline (e.g. at -O0)
// would fail to link.
//
static inline __m512 next_normal(void){
  if (normals_count==0) {
    vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF,
                  stream,16*NRV,(float *)normals,0.0f,1.0f);
    normals_count = NRV;
  }
  return normals[--normals_count];
}

inline __m512 next_normal_fp16(){
  if (normals_fp16_count==0) {
    viRngUniformBits32(VSL_RNG_METHOD_UNIFORMBITS32_STD,
                       stream,16*NRV,(uint *)normals_fp16);
#pragma omp unroll partial(2)
    for (int n=0; n<NRV; n++)
      normals_fp16[n] = norminv_fp16((__m512i) normals_fp16[n]);

    normals_fp16_count = NRV;
  }
  return normals_fp16[--normals_fp16_count];
}

//
// main code
//

int main(int argc, char **argv)
{
  double    wtime=0.0;
  long long N = ((long long) 1)<<31; /* number of AVX vectors sampled */

// initialise generator, with separate storage for each thread

#pragma omp parallel
  rng_initialisation();

// time execution of loop which generates Normals

  __m512 sum2 = _mm512_setzero_ps();
  float  arr2[16];

  wtime = omp_get_wtime();

#pragma omp parallel for default(none) shared(N) reduction(+:sum2)
  for (long long n=0; n<N; n++) {
    sum2 = sum2 + next_normal();
  }
    
  wtime = omp_get_wtime() - wtime;

  _mm512_store_ph(arr2, sum2);
  // printf("arr = %g %g %g %g \n",arr2[0],arr2[1],arr2[2],arr2[3]);

  // print out results
  
  printf("fp32 Normal RNG tests:\n");
  printf("max threads         = %10d\n",omp_get_max_threads());
  printf("execution time (ms) = %10.4g\n",1000.0*wtime);
  printf("32bit RNG/s         = %10.4g\n\n",16.0*N/wtime);

// time execution of loop which generates random bits

  __m512i isum = _mm512_setzero_epi32();
  int     iarr[16];

  wtime = omp_get_wtime();

#pragma omp parallel for default(none) shared(N) reduction(+:isum)
  for (long long n=0; n<N; n++) {
    isum = isum + next_uniformbits();
  }
    
  wtime = omp_get_wtime() - wtime;

  _mm512_store_epi32(iarr, isum);
  // printf("arr = %d %d %d %d \n",iarr[0],iarr[1],iarr[2],iarr[3]);

  // print out results
  
  printf("UniformBits32 RNG tests:\n");
  printf("max threads         = %10d\n",omp_get_max_threads());
  printf("execution time (ms) = %10.4g\n",1000.0*wtime);
  // 16 32-bit elements per vector
  printf("32bit RNG/s         = %10.4g\n",16.0*N/wtime);
  // 2 for write+read, 512-bit=64 bytes
  printf("bandwidth (GB/s)    = %10.4g\n\n",2.0*64.0*N/wtime/1e9);

// time execution of loop which generates random bits
// and transforms them into fp16 Normals using superdyadic approx

  __m512h  sum;
  _Float16 arr[32];

  sum   = _mm512_setzero_ph();
  
  wtime = omp_get_wtime();
  
#pragma omp parallel for default(none) shared(N) reduction(+:sum)
  for (long long n=0; n<N; n++) {
    sum = sum + norminv_fp16(next_uniformbits());
  }
    
  wtime = omp_get_wtime() - wtime;

  aligned_storeh(arr,sum);
  // printf("arr = %10g %10g %10g %10g \n", (float) arr[0],
  //        (float) arr[1], (float) arr[2], (float) arr[3]);

  // print out results
  
  printf("UniformBits32/norminv_fp16 tests:\n");
  printf("max threads         = %10d\n",omp_get_max_threads());
  printf("execution time (ms) = %10.4g\n",1000.0*wtime);
  // 32 16-bit elements per vector
  printf("16bit RNG/s         = %10.4g\n\n",32.0*N/wtime);

// time execution of alternative loop which generates random bits
// and transforms them into fp16 Normals using superdyadic approx

  sum   = _mm512_setzero_ph();
  
  wtime = omp_get_wtime();
  
#pragma omp parallel for default(none) shared(N) reduction(+:sum)
  for (long long n=0; n<N; n++) {
    sum = sum + 0.25*next_normal_fp16();
  }
    
  wtime = omp_get_wtime() - wtime;

  aligned_storeh(arr,sum);
  // printf("arr = %10g %10g %10g %10g \n", (float) arr[0],
  //        (float) arr[1], (float) arr[2], (float) arr[3]);

  // print out results
  
  printf("fp16 Normal RNG tests:\n");
  printf("max threads         = %10d\n",omp_get_max_threads());
  printf("execution time (ms) = %10.4g\n",1000.0*wtime);
  // 32 16-bit elements per vector
  printf("16bit RNG/s         = %10.4g\n\n",32.0*N/wtime);

// time execution of loop which generates random bits
// and transforms them into fp16 Normals using LUT

  __declspec(align(64)) _Float16 LUT[2*65536];
  
  for (int i=0; i<65536; i++) {
    double um = fmax( i   /65536.0, 1e-15);
    double up = fmin((i+1)/65536.0, 1.0-1e-15);
    double x  = 65536.0*( pdfnorm(cdfnorminv(um))
                        - pdfnorm(cdfnorminv(up)) );
    LUT[2*i  ] = x;
    LUT[2*i+1] = x;
  }

  sum   = _mm512_setzero_ph();
  wtime = omp_get_wtime();

#pragma omp parallel for default(none) shared(N,LUT) reduction(+:sum)
  for (long long n=0; n<N; n++) {
    sum = sum + gather_fp16(LUT,next_uniformbits());
  }

  wtime = omp_get_wtime() - wtime;
  
  aligned_storeh(arr,sum);
  // printf("arr = %10.4g %10.4g %10.4g %10.4g \n", (float) arr[0],
  //                (float) arr[1], (float) arr[2], (float) arr[3]);

  // print out results
  
  printf("UniformBits32/gather_fp16 tests:\n");
  printf("max threads         = %10d\n",omp_get_max_threads());
  printf("execution time (ms) = %10.4g\n",1000.0*wtime);
  // 32 16-bit elements per vector
  printf("16bit RNG/s         = %10.4g\n\n",32.0*N/wtime);

// delete generator and storage
#pragma omp parallel 
  rng_termination();
}
