/*******************************************************************************
* Copyright 2014-2020 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file TestSymmetry.cpp

 HPCG routine
 */

// The MPI include must be first for Windows platforms
#ifndef HPCG_NO_MPI
#include <mpi.h>
#endif
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <vector>
using std::endl;

#include "hpcg.hpp"

#include "ComputeSPMV.hpp"
#include "ComputeSPMV_ref.hpp"
#include "ComputeMG.hpp"
#include "ComputeMG_ref.hpp"
#include "ComputeDotProduct.hpp"
#include "ComputeResidual.hpp"
#include "Geometry.hpp"
#include "SparseMatrix.hpp"
#include "TestSymmetry.hpp"

// We typically compile with optimizations (fastmath) that make
// std::isnan always return false. Integer operations are not affected by
// those optimizations, so the NaN check is done on the raw bit pattern of
// the value instead: an IEEE-754 binary64 NaN has all eleven exponent bits
// set and a non-zero mantissa. The sign bit is ignored on purpose so that
// negative NaNs (printed as "-nan" by some C libraries) are detected too.
//
// @param[in] x value to classify
// @return true if x is a NaN (of either sign, quiet or signaling), false otherwise
bool str_isnan(double x) {
    static_assert(sizeof(double) == sizeof(std::uint64_t),
                  "str_isnan expects a 64-bit double");

    // memcpy is the well-defined way to look at the object representation
    std::uint64_t bits = 0;
    std::memcpy(&bits, &x, sizeof(bits));

    const std::uint64_t exponentMask = 0x7FF0000000000000ULL;
    const std::uint64_t mantissaMask = 0x000FFFFFFFFFFFFFULL;

    // All-ones exponent with zero mantissa is +/-infinity, not NaN
    return (bits & exponentMask) == exponentMask && (bits & mantissaMask) != 0;
}

/*!
  Tests symmetry-preserving properties of the sparse matrix vector multiply and multi-grid routines.

  Two vectors x and y are filled with random values. The routine then compares
  x'*A*y with y'*A*x (SpMV) and x'*Minv*y with y'*Minv*x (MG). Each absolute
  difference is divided by (xNorm2*ANorm*yNorm2 + yNorm2*ANorm*xNorm2) * DBL_EPSILON
  and stored in testsymmetry_data; a scaled value that is NaN or greater than 1
  increments testsymmetry_data.count_fail. Finally A*xexact is compared against b
  twice and the residuals are written to the log.

  @param[in]    A      The known system matrix
  @param[in]    b      The known right hand side vector
  @param[in]    xexact The exact solution vector
  @param[inout] testsymmetry_data The data structure with the results of the CG symmetry test including pass/fail information
  @param[in]    main_queue The SYCL queue used for the temporary shared allocations and handed to every compute routine

  @return always returns 0; errors flagged by the compute routines are only written to the log,
          and symmetry failures are reported through testsymmetry_data.count_fail

  @see ComputeDotProduct
  @see ComputeDotProduct_ref
  @see ComputeSPMV
  @see ComputeSPMV_ref
  @see ComputeMG
  @see ComputeMG_ref
*/
int TestSymmetry(SparseMatrix & A, Vector & b, Vector & xexact, TestSymmetryData & testsymmetry_data, sycl::queue & main_queue) {

 local_int_t nrow = A.localNumberOfRows;
 local_int_t ncol = A.localNumberOfColumns;

 // Work vectors sized by the local column count
 Vector x_ncol, y_ncol, z_ncol;
 InitializeVectorShared(x_ncol, ncol, main_queue);
 InitializeVectorShared(y_ncol, ncol, main_queue);
 InitializeVectorShared(z_ncol, ncol, main_queue);

 // One-element buffers that receive the dot-product results
 double *xNorm2 = (double *)sparse_malloc_shared(sizeof(double) * 1, main_queue);
 double *yNorm2 = (double *)sparse_malloc_shared(sizeof(double) * 1, main_queue);
 double *xtAy = (double *)sparse_malloc_shared(sizeof(double) * 1, main_queue);
 double *ytAx = (double *)sparse_malloc_shared(sizeof(double) * 1, main_queue);
 double *xtMinvy = (double *)sparse_malloc_shared(sizeof(double) * 1, main_queue);
 double *ytMinvx = (double *)sparse_malloc_shared(sizeof(double) * 1, main_queue);

 double t4 = 0.0; // Needed for dot-product call, otherwise unused
 testsymmetry_data.count_fail = 0;

 // Error flag handed to the SpMV / MG / residual routines. ComputeDotProduct
 // is not given this flag, so it is only inspected after calls that receive
 // it, and it is cleared after every report so that one failure is not
 // attributed to a later call.
 int ierr = 0;
 auto reportAndClear = [&ierr](const char * routine) {
   if (ierr) HPCG_fout << "Error in call to " << routine << ": " << ierr << ".\n" << endl;
   ierr = 0;
 };

 // Test symmetry of matrix

 // First load vectors with random values
 FillRandomVector(x_ncol);
 FillRandomVector(y_ncol);

 // NOTE(review): presumably a bound on the norm of A for the 27-point stencil — confirm
 double ANorm = 2 * 26.0;

 // Next, compute x'*A*y
 auto ev = ComputeDotProduct(nrow, y_ncol, y_ncol, yNorm2, t4, main_queue);
 ComputeSPMV(A, y_ncol, z_ncol, main_queue, ierr, {ev}).wait(); // z_ncol = A*y_ncol, ordered after the y'*y dot product
 reportAndClear("SpMV");
 ComputeDotProduct(nrow, x_ncol, z_ncol, xtAy, t4, main_queue).wait();

 // Next, compute y'*A*x
 ComputeDotProduct(nrow, x_ncol, x_ncol, xNorm2, t4, main_queue).wait();
 ComputeSPMV(A, x_ncol, z_ncol, main_queue, ierr).wait(); // z_ncol = A*x_ncol
 reportAndClear("SpMV");
 ComputeDotProduct(nrow, y_ncol, z_ncol, ytAx, t4, main_queue).wait();

 // Scaling shared by both symmetry checks; xNorm2 and yNorm2 are not written again below
 const double scale = (xNorm2[0]*ANorm*yNorm2[0] + yNorm2[0]*ANorm*xNorm2[0]) * (DBL_EPSILON);

 testsymmetry_data.depsym_spmv = std::fabs((long double) (xtAy[0] - ytAx[0]))/scale;

 if (str_isnan(testsymmetry_data.depsym_spmv) || testsymmetry_data.depsym_spmv > 1.0) ++testsymmetry_data.count_fail;  // If the difference is > 1, count it wrong
 if (A.geom->rank==0) HPCG_fout << "Departure from symmetry (scaled) for SpMV abs(x'*A*y - y'*A*x) = " << testsymmetry_data.depsym_spmv << endl;

 // Test symmetry of multi-grid

 // Compute x'*Minv*y
 ComputeMG(A, y_ncol, z_ncol, main_queue, ierr).wait(); // z_ncol = Minv*y_ncol
 reportAndClear("MG");
 ComputeDotProduct(nrow, x_ncol, z_ncol, xtMinvy, t4, main_queue).wait();

 // Next, compute y'*Minv*x
 ComputeMG(A, x_ncol, z_ncol, main_queue, ierr).wait(); // z_ncol = Minv*x_ncol
 reportAndClear("MG");
 ComputeDotProduct(nrow, y_ncol, z_ncol, ytMinvx, t4, main_queue).wait();

 testsymmetry_data.depsym_mg = std::fabs((long double) (xtMinvy[0] - ytMinvx[0]))/scale;
 if (str_isnan(testsymmetry_data.depsym_mg) || testsymmetry_data.depsym_mg > 1.0) ++testsymmetry_data.count_fail;  // If the difference is > 1, count it wrong
 if (A.geom->rank==0) HPCG_fout << "Departure from symmetry (scaled) for MG abs(x'*Minv*y - y'*Minv*x) = " << testsymmetry_data.depsym_mg << endl;

 CopyVector(xexact, x_ncol, main_queue).wait(); // Copy exact answer into overlap vector

 // Multiply A by the exact solution and log how far the product is from b
 int numberOfCalls = 2;
 double residual = 0.0;
 for (int i=0; i< numberOfCalls; ++i) {
   ComputeSPMV(A, x_ncol, z_ncol, main_queue, ierr).wait(); // z_ncol = A*xexact
   reportAndClear("SpMV");
   ComputeResidual(A.localNumberOfRows, b, z_ncol, residual, ierr, main_queue).wait();
   reportAndClear("compute_residual");
   if (A.geom->rank==0) HPCG_fout << "SpMV call [" << i << "] Residual [" << residual << "]" << endl;
 }

 // Release the work vectors and the scalar result buffers
 DeleteVector(x_ncol, main_queue);
 DeleteVector(y_ncol, main_queue);
 DeleteVector(z_ncol, main_queue);
 sycl::free(xNorm2, main_queue.get_context());
 sycl::free(yNorm2, main_queue.get_context());
 sycl::free(xtAy, main_queue.get_context());
 sycl::free(ytAx, main_queue.get_context());
 sycl::free(xtMinvy, main_queue.get_context());
 sycl::free(ytMinvx, main_queue.get_context());

 return 0;
}

