#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#ifdef _OPENMP
#include <omp.h>
#endif

/* <math.h> provides M_PI only under POSIX/XSI (or _USE_MATH_DEFINES); define
   it ourselves for strict ISO C builds without clashing with the library's. */
#ifndef M_PI
#define M_PI	3.14159265358979323846	/* pi */
#endif

/* C program to compute a finite difference
   approximation to the 2D heat equation as
   a test program for OpenMP parallelisation */

/* Set the initial condition u(x,y) = sin(pi x) sin(pi y) on an m-by-m grid
   with spacing dx, stored row-major (point (i,j) at index i + j*m), and zero
   the second buffer.  u and unew must be distinct arrays of m*m doubles. */
static void init_grid(int m, double dx, double *restrict u, double *restrict unew)
{
#pragma omp parallel for default(none) shared(m,u,unew,dx)
  for (int j=0; j<m; j++) {
    for (int i=0; i<m; i++) {
      u[i+j*m]    = sin(M_PI*i*dx) * sin(M_PI*j*dx);
      unew[i+j*m] = 0.0;
    }
  }
}

/* Advance one explicit time step with the 5-point stencil
     unew = (1 - 4 nu) u + nu (u[i+1] + u[i-1] + u[j+1] + u[j-1])
   on interior points only; boundary entries of unew are not written.
   restrict on the parameters lets the compiler assume u and unew do not
   overlap when optimising the inner loop. */
static void heat_step(int m, double nu, const double *restrict u,
                      double *restrict unew)
{
#pragma omp parallel for default(none) shared(m,nu,u,unew)
  for (int j=1; j<m-1; j++) {
    for (int i=1; i<m-1; i++) {
      int ind = i + j*m;
      unew[ind] = (1.0-4.0*nu)*u[ind]
                + nu*(u[ind+1] + u[ind-1] + u[ind+m] + u[ind-m]);
    }
  }
}

/* Integrate the 2D heat equation from t = 0 to t = 1 on an m-by-m grid,
   printing the mid-point value at t = 0 and every n/10 steps.  Afterwards
   report flop/byte counts, plus rates and wall time when built with OpenMP. */
int main(void)
{
  int    m, n, mid;
  double sig2, t, dt, dx, nu;

  /* Plain (not restrict) pointers: they are swapped every step, and assigning
     one restrict pointer to another declared in the same block is undefined
     behaviour (C11 6.7.3.1).  The no-alias guarantee is instead given on the
     parameters of init_grid/heat_step. */
  double *v1, *v2;

  /* dynamic memory allocation */

  m = 1001;
  size_t npts = (size_t)m * (size_t)m;
  v1 = malloc(npts * sizeof *v1);
  v2 = malloc(npts * sizeof *v2);
  if (v1 == NULL || v2 == NULL) {
    fprintf(stderr, "error: cannot allocate two %d x %d grids\n", m, m);
    free(v1);
    free(v2);
    return EXIT_FAILURE;
  }

  /* initialisation */

#ifdef _OPENMP
  double wtime = omp_get_wtime();
#endif

  sig2 = 0.025;

  dx  = 1.0 / (m-1);
  dt  = 0.2*dx*dx/sig2;     // gives nu = 0.2, keeping 1-4*nu >= 0 in the update
  t   = 1.0;
  n   = (int)(t/dt);
  n   = 100*((n+99)/100);   // round UP to a multiple of 100 so dt does not grow
  if (n < 100) {
    n = 100;                // at least 100 steps; also keeps n/10 > 0 below
  }
  dt  = t/n;
  nu  = sig2*dt/(dx*dx);    // from the final dt, so nn*dt is the simulated time

  mid = m/2;

  printf("\n#timesteps = %d\n\n",n);

  init_grid(m, dx, v1, v2);

  printf("time, mid-point value = %f, %f\n",0.0,v1[mid+mid*m]);

  /* do time-stepping */

  int print_every = n/10;   // > 0 because n >= 100
  for (int nn=1; nn<=n; nn++) {
    heat_step(m, nu, v1, v2);

    // swap pointers so new solution is now in v1
    double *tmp=v1; v1=v2; v2=tmp;

    if (nn%print_every == 0) {
      printf("time, mid-point value = %f, %f\n",nn*dt,v1[mid+mid*m]);
    }
  }

  double inner  = (double)(m-2)*(m-2);  // interior points updated per step
  double flops  =  6.0*inner*n;  // 6 flops per inner grid point
  double bytes1 = 48.0*inner*n;  // 5 loads + 1 store of 8 bytes per inner point
  double bytes2 = 16.0*m*m*n;    // load/store v1, v2 for each timestep

  printf("\nflops executed     = %10.4g\n",flops);
  printf("bytes (L1->regs)   = %10.4g\n",bytes1);
  printf("bytes (L2->L1)     = %10.4g\n\n",bytes2);

#ifdef _OPENMP
  wtime = omp_get_wtime() - wtime;
  printf("threads            = %d\n",omp_get_max_threads());
  printf("execution time     = %10.4g\n",wtime);
  printf("flops/s            = %10.4g\n",flops/wtime);
  printf("bytes/s (L1->regs) = %10.4g\n",bytes1/wtime);
  printf("bytes/s (L2->L1)   = %10.4g\n",bytes2/wtime);
#endif

  free(v1);
  free(v2);
  return 0;
}

