/*

This code performs fp16 tests, first using the scalar 
_Float16 datatype and then using AVX-512 vectors

*/

#include <iostream>
#include <iomanip>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include "my_mm512.h"
#include <immintrin.h>
#include <mathimf.h>

#define LUT_size 65536

// Runtime query for the "avx512dq" CPU feature.
// Uses __builtin_cpu_supports when built with GCC or Clang; with any other
// compiler no query is made and the function reports false.
bool cpu_supports_avx512dq() {
  bool has_feature = false;
#if defined(__clang__) || defined(__GNUC__)
  has_feature = (__builtin_cpu_supports("avx512dq") != 0);
#endif
  return has_feature;
}

// Runtime query for the "avx512fp16" CPU feature.
// Uses __builtin_cpu_supports when built with GCC or Clang; with any other
// compiler no query is made and the function reports false.
bool cpu_supports_avx512fp16() {
  bool has_feature = false;
#if defined(__clang__) || defined(__GNUC__)
  has_feature = (__builtin_cpu_supports("avx512fp16") != 0);
#endif
  return has_feature;
}

// Debug helper: spills the fp16 vector x into a 64-byte aligned buffer of
// 32 values and prints the first 8 of them (converted to float), one per
// line, followed by a separator line.
void _mm512h_print(__m512h x) {
  __declspec(align(64)) _Float16 lanes[32];
  _mm512_store_ph(lanes, x);

  int lane = 0;
  while (lane < 8) {
    printf(" %f \n", (float) lanes[lane]);
    ++lane;
  }
  printf("--------------------\n");
}

void _mm512epi16_print(__m512i i) {
  __declspec(align(64)) short iarr[32];
  _mm512_storeu_epi16(iarr, i);
  for (int k=0; k<8; k++) printf(" %d \n", (int) iarr[k]);
  printf("--------------------\n");
}


//
// Normal pdf function, not defined in mathimf.h
//

// Standard Normal density:  exp(-x^2/2) / sqrt(2*pi).
//   x      : evaluation point
//   return : density value at x
double pdfnorm(double x) {
  const double exponent  = -0.5 * x * x;
  const double normalise = sqrt(2.0 * M_PI);
  return exp(exponent) / normalise;
}

//
// super-dyadic piecewise linear approximation to the inverse CDF
// of the Normal distribution
//

// Piecewise-linear fp16 function of a 16-bit index.
//
//   i      : 16-bit index.  Indices >= 2^15 are mirrored onto the lower
//            half (i -> 65535 - i) and the result is negated, so
//            normal_dyadic2(65535 - i) == -normal_dyadic2(i).
//   return : con1[j]*x + con2[j]  with  x = i * 2^(-7), where the segment
//            number j is the fp16 bit pattern of x*x shifted right by 10
//            (i.e. its sign and exponent bits).
//
// The coefficient tables are static const so they are built once instead
// of on every call, and the bit pattern of x*x is obtained with memcpy
// rather than by reading an inactive union member.
_Float16 normal_dyadic2(ushort i) {
  // slope of each linear segment
  static const _Float16 con1[32] = {
    0.0000000f16,   0.0000000f16,   0.0000000f16,   0.0000000f16, 
    0.0000000f16,   6.4032523f16,   4.6524331f16,   3.3429773f16, 
    2.4320530f16,   1.7896978f16,   1.2990414f16,   0.9409888f16, 
    0.6846983f16,   0.5001156f16,   0.3646161f16,   0.2660284f16, 
    0.1946923f16,   0.1428579f16,   0.1050341f16,   0.0774698f16, 
    0.0573213f16,   0.0425913f16,   0.0317966f16,   0.0238803f16, 
    0.0180653f16,   0.0137943f16,   0.0106610f16,   0.0083761f16, 
    0.0067380f16,   0.0056217f16,   0.0049986f16,   0.0000000f16
  };

  // offset of each linear segment
  static const _Float16 con2[32] = {
   -4.3875114f16,  -4.0804306f16,   0.0000000f16,  -3.9572776f16, 
   -3.8753599f16,  -4.0135121f16,  -3.9394078f16,  -3.8606159f16, 
   -3.7800080f16,  -3.7008669f16,  -3.6146951f16,  -3.5256808f16, 
   -3.4346582f16,  -3.3423427f16,  -3.2462729f16,  -3.1474375f16, 
   -3.0459985f16,  -2.9419128f16,  -2.8343864f16,  -2.7236453f16, 
   -2.6091068f16,  -2.4907230f16,  -2.3679991f16,  -2.2407355f16, 
   -2.1085159f16,  -1.9711855f16,  -1.8287113f16,  -1.6818128f16, 
   -1.5329659f16,  -1.3898027f16,  -1.2778941f16,   0.0000000f16
  };

  static_assert(sizeof(_Float16) == sizeof(unsigned short),
                "fp16 bit pattern must fit an unsigned short");

  // Upper half of the index range is the mirror image of the lower half.
  const bool mirrored = (i >= (1<<15));
  if (mirrored) i = ((1<<16)-1) - i;

  _Float16 x = 0.0078125f16 * ((_Float16) i);  //  0.0078125 = 2^(-7)

  // x*x is never negative, so its sign bit is clear and shifting the bit
  // pattern right by 10 leaves the 5-bit exponent field: a value in 0..31.
  // If x*x overflows to +inf the exponent field is 31, whose table
  // entries are both zero.
  const _Float16 sq = x*x;
  unsigned short bits;
  memcpy(&bits, &sq, sizeof bits);
  const unsigned seg = bits >> 10;

  x = con1[seg]*x + con2[seg];
  if (mirrored) x = -x;

  return x;
}


int main() {

  printf("\nFirst, scalar tests using _Float16 datatype");
  printf("\n===========================================\n");
  
  //
  // construct uniform LUT table, and print comparison of
  // uniform, super-dyadic, and "true" inverse CDF values
  // for first few mid-points
  
  __declspec(align(64)) _Float16 LUT_uniform[LUT_size];

  double du = 1.0 / LUT_size;
  double err=0.0;

  for (int i=0; i<LUT_size; i++) {
    double um = fmax( i   *du, 1e-15);
    double up = fmin((i+1)*du, 1.0-1e-15);
    LUT_uniform[i] = ( pdfnorm(cdfnorminv(um))
                     - pdfnorm(cdfnorminv(up)) ) / du;

    double diff = LUT_uniform[i] - normal_dyadic2((ushort) i);
    err += diff*diff;
  }

  printf("MSE between LUT and superdyadic: %g \n",err/LUT_size);

  //
  // now determine MSE of uniform LUT and super-dyadic approximations
  //

  int    IMAX = 1<<24;
  double err1=0.0, err2=0.0; 

  for (int i=0; i<IMAX; i++) {
    ushort   j = i>>8;

    double x = cdfnorminv((i+0.5)/IMAX);
    _Float16 x_LUT = LUT_uniform[j];
    _Float16 x_dy2 = normal_dyadic2(j);

    err1 += (x-x_LUT)*(x-x_LUT);
    err2 += (x-x_dy2)*(x-x_dy2);

    // if (i<8) printf(" %f \n",x);
  }

  printf("Uniform LUT MSE = %g \n", err1/IMAX);
  printf("Superdyadic MSE = %g \n", err2/IMAX);

//
//------------------------------------------------------------
//
  
  printf("\nNext, vector tests using __m512h datatype");
  printf("\n=========================================\n");

  if (!cpu_supports_avx512fp16() || !cpu_supports_avx512dq()) {
    std::cerr << "Error: AVX-512 FP16 or DQ not supported on this system\n";
    return EXIT_FAILURE;
  }

  __declspec(align(64)) _Float16 LUT_double[2*LUT_size];
  
  for (int i=0; i<LUT_size; i++) {
    LUT_double[2*i  ] = LUT_uniform[i];
    LUT_double[2*i+1] = LUT_uniform[i];
  }
  
  __declspec(align(64)) ushort   iarr_h[32];
  __declspec(align(64)) _Float16 xarr_h[32];
  __declspec(align(64)) _Float16 xarr_hb[32];

  __m512i i_32;
  __m512h x_32;
  
  err = 0.0;

  for (int i=0; i<LUT_size; i+=32) {
    for (int k=0; k<32; k++) iarr_h[k] = i+k;
    i_32 = _mm512_loadu_epi16(iarr_h);

    x_32 = gather_fp16(LUT_double,i_32);
    aligned_storeh(xarr_h, x_32);

    x_32 = norminv_fp16(i_32);
    aligned_storeh(xarr_hb, x_32);

    for (int k=0; k<32; k++) {
      double diff =  xarr_h[k] - xarr_hb[k];
      err += diff*diff;
    }
  }

  printf("MSE between LUT and superdyadic: %g \n",err/LUT_size);

  //
  // now determine MSE of uniform LUT and super-dyadic approximations
  //

  err1=0.0, err2=0.0;

  for (int i=0; i<IMAX; i+=32) {
    for (int k=0; k<32; k++) iarr_h[k] = (i+k)>>8;
    i_32 = _mm512_loadu_epi16(iarr_h);

    x_32 = gather_fp16(LUT_double,i_32);
    aligned_storeh(xarr_h, x_32);

    x_32 = norminv_fp16(i_32);
    aligned_storeh(xarr_hb, x_32);

    for (int k=0; k<32; k++) {
      double x = cdfnorminv((i+k+0.5)/IMAX);
      double diff;
      
      diff  = xarr_h[k] - x;
      err1 += diff*diff;
      diff  = xarr_hb[k] - x;
      err2 += diff*diff;
    }
  }

  printf("Uniform LUT MSE = %g \n", err1/IMAX);
  printf("Superdyadic MSE = %g \n", err2/IMAX); 
}
  
