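//
// Header file to simplify use of FFTW routines in MLMC. It keeps one set of
// FFTW plans and input/output buffers per OpenMP thread.
//

// Notes:
//
// 1) FFTW_RODFT00 (a DST-I) includes a factor of 2 in the sine summation:
//      Y_k = 2 sum_{j=0}^{n-1} X_j sin(pi (j+1)(k+1) / (n+1))
//    https://www.fftw.org/fftw3_doc/1d-Real_002dodd-DFTs-_0028DSTs_0029.html
//
// 2) FFTW_HC2R performs a DFT whose input is Hermitian data packed into n reals
//    ("halfcomplex" format) and whose output is real. Like all FFTW transforms,
//    it is unnormalised.
//    https://www.fftw.org/fftw3_doc/The-Halfcomplex_002dformat-DFT.html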

#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <fftw3.h>
#include <omp.h>

/*
 * Per-thread FFTW state. Because of the threadprivate directive below, every
 * OpenMP thread owns its own copy of each of these pointers.
 *   fftw_plans : heap array of plans, one per level l
 *   fftw_in    : input buffer used by all of this thread's plans
 *   fftw_out   : output buffer used by all of this thread's plans
 * Static storage means each copy starts out as a null pointer.
 */
static fftw_plan *fftw_plans;
static double    *fftw_in;
static double    *fftw_out;
#pragma omp threadprivate(fftw_plans, fftw_in, fftw_out)

// Shared implementation of the two initialisers below. It must run on every
// OpenMP thread that will call fftw_calc, because the plan table and the
// buffers are threadprivate.
//
// It allocates the calling thread's plan table (Lmax+1 entries) and its in/out
// buffers (J0<<Lmax doubles each). It then creates one r2r plan per level
// l = 0..Lmax, each reading fftw_in and writing fftw_out:
//   periodic == 0 : FFTW_RODFT00 of length (J0<<l) - 1
//   periodic != 0 : FFTW_HC2R    of length  J0<<l
// Every length is at most J0<<Lmax, so all levels share the same buffers.
// Invalid arguments, allocation failure and planner failure are reported on
// stderr and terminate the program.
static void fftw_initialisation_levels(int Lmax, int J0, int periodic){
  // J0<<Lmax must be a representable positive int. Shifting a signed int by
  // at least its width, or into/past the sign bit, is undefined behaviour.
  if (Lmax < 0 || Lmax >= (int)(sizeof(int)*CHAR_BIT) || J0 < 1
      || J0 > (INT_MAX >> Lmax)) {
    fprintf(stderr, "FFTW initialisation: invalid Lmax=%d, J0=%d \n", Lmax, J0);
    exit(EXIT_FAILURE);
  }

  const int len = J0 << Lmax;   // largest transform length over all levels

  // FFTW documents only fftw_execute as thread-safe, so the planner (and, to
  // be safe, FFTW's allocator) must not run in several threads at once. An
  // unnamed critical section acts as a single program-wide lock.
#pragma omp critical
  {
    fftw_plans = malloc((size_t)(Lmax+1) * sizeof *fftw_plans);
    fftw_in    = fftw_alloc_real(len);
    fftw_out   = fftw_alloc_real(len);
    if (!fftw_plans || !fftw_in || !fftw_out) {
      fprintf(stderr, "FFTW memory allocation failure for n=%d \n", len);
      exit(EXIT_FAILURE);
    }

    for (int l=0; l<=Lmax; l++) {
      // RODFT00 is odd about j=-1 and j=n (implied zero end values), so the
      // Dirichlet case uses one point fewer. Presumably these are the
      // interior points of J0<<l intervals.
      int n = periodic ? (J0<<l) : (J0<<l) - 1;

      // FFTW_ESTIMATE plans without overwriting the (uninitialised) buffers.
      fftw_plans[l] = fftw_plan_r2r_1d(n, fftw_in, fftw_out,
                                       periodic ? FFTW_HC2R : FFTW_RODFT00,
                                       FFTW_ESTIMATE);
      if (!fftw_plans[l]) {
        fprintf(stderr, "FFTW plan creation failure for n=%d \n", n);
        exit(EXIT_FAILURE);
      }
    }
  }
}

// Set up this thread's Dirichlet (sine) transforms. Level l uses an
// FFTW_RODFT00 plan of length (J0<<l) - 1, for l = 0..Lmax.
void fftw_initialisation_dirichlet(int Lmax, int J0){
  fftw_initialisation_levels(Lmax, J0, 0);
}

// Set up this thread's periodic transforms. Level l uses an FFTW_HC2R plan of
// length J0<<l, for l = 0..Lmax.
void fftw_initialisation_periodic(int Lmax, int J0){
  fftw_initialisation_levels(Lmax, J0, 1);
}

// Release the calling thread's plans and buffers.
//
// Lmax must be the value passed to the initialisation routine on this thread
// (the plan table holds Lmax+1 entries). Calling this on a thread that was
// never initialised, or calling it twice, is a harmless no-op. The pointers
// are reset to NULL so the thread can be initialised again.
void fftw_termination(int Lmax){
  // fftw_destroy_plan (like every FFTW routine except fftw_execute) is not
  // thread-safe, so serialise it under the program-wide unnamed critical lock.
#pragma omp critical
  {
    if (fftw_plans) {
      for (int l=0; l<=Lmax; l++) {
        if (fftw_plans[l]) fftw_destroy_plan(fftw_plans[l]);
      }
      free(fftw_plans);
      fftw_plans = NULL;
    }
    // fftw_free's handling of NULL is not documented, so guard it explicitly
    if (fftw_in)  { fftw_free(fftw_in);  fftw_in  = NULL; }
    if (fftw_out) { fftw_free(fftw_out); fftw_out = NULL; }
  }
}

/*
 * Execute this thread's level-l plan: it reads fftw_in and writes fftw_out.
 * fftw_execute is the FFTW routine documented as thread-safe, so no locking
 * is needed here. The caller must ensure 0 <= l <= Lmax of this thread's
 * initialisation; this is not checked.
 */
void fftw_calc(int l){
  fftw_execute(*(fftw_plans + l));
}

