//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Sim/Simulation/DepthprobeSimulation.cpp
//! @brief     Implements class DepthprobeSimulation.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Sim/Simulation/DepthprobeSimulation.h"
#include "Base/Axis/Frame.h"
#include "Base/Axis/Scale.h"
#include "Base/Const/PhysicalConstants.h"
#include "Base/Progress/ProgressHandler.h"
#include "Base/Util/Assert.h"
#include "Base/Vector/GisasDirection.h"
#include "Device/Beam/IFootprint.h"
#include "Device/Data/Datafield.h"
#include "Param/Distrib/DistributionHandler.h"
#include "Param/Distrib/Distributions.h"
#include "Resample/Element/IElement.h"
#include "Resample/Flux/ScalarFlux.h"
#include "Resample/Processed/ReSample.h"
#include "Sim/Scan/AlphaScan.h"
#include <memory>
#include <stdexcept>
#include <valarray>

using PhysConsts::pi;

// Bit flags accepted through the `flags` constructor argument of DepthprobeSimulation.
// Bits 0-1 choose which z-direction component of the wave field is evaluated;
// 0 means the superposition of the reflected and transmitted waves.
const int ZDirection_None = 0;
const int ZDirection_Reflected = 1 << 0;
const int ZDirection_Transmitted = 1 << 1;
// Bits 2-3 choose which property of the wave field is recorded;
// 0 means the intensity |psi|^2.
const int WaveProperty_Intensity = 0;
const int WaveProperty_Modulus = 1 << 2;
const int WaveProperty_Phase = 1 << 3;

DepthprobeSimulation::DepthprobeSimulation(const BeamScan& scan, const Sample& sample,
                                           const Scale& zaxis, int flags)
    : ISimulation(sample)
    , m_scan(dynamic_cast<AlphaScan*>(scan.clone()))
    , m_z_axis(zaxis.clone())
    , m_flags(flags)
{
    if (!m_scan)
        throw std::runtime_error("DepthprobeSimulation not implemented for non-alpha scan");
}

DepthprobeSimulation::~DepthprobeSimulation() = default;

//! Returns the child nodes inherited from ISimulation, followed by the scan.
std::vector<const INode*> DepthprobeSimulation::nodeChildren() const
{
    auto children = ISimulation::nodeChildren();
    children.emplace_back(m_scan.get());
    return children;
}

//... Overridden executors:

//! init callbacks for setting the parameter values
void DepthprobeSimulation::initDistributionHandler()
{
    for (const auto& distribution : distributionHandler().paramDistributions()) {

        switch (distribution.whichParameter()) {
        case ParameterDistribution::BeamInclinationAngle: {
            distributionHandler().defineCallbackForDistribution(
                &distribution, [&](double d) { m_scan->setAlphaOffset(d); });
            break;
        }
        case ParameterDistribution::BeamWavelength:
            distributionHandler().defineCallbackForDistribution(
                &distribution, [&](double d) { m_scan->setWavelength(d); });
            break;
        default:
            ASSERT_NEVER;
        }
    }
}

void DepthprobeSimulation::runComputation(const ReSample& re_sample, size_t i, double weight)
{
    if (m_scan->wavelengthDistribution() || m_scan->alphaDistribution())
        throw std::runtime_error(
            "Depthprobe simulation supports neither alpha nor lambda distributions.");

    const size_t n_z = m_z_axis->size();
    std::valarray<double> intensities; //!< simulated intensity for given z positions
    intensities.resize(n_z, 0.0);

    const double result_angle = m_scan->coordinateAxis()->binCenter(i) + m_scan->alphaOffset();
    if (0 < result_angle && result_angle < (pi / 2)) {
        const size_t n_layers = re_sample.numberOfSlices();
        size_t start_z_ind = n_z;

        const R3 ki = vecOfLambdaAlphaPhi(m_scan->wavelengthAt(i), -result_angle);
        const Fluxes fluxes = re_sample.fluxesIn(ki);

        double z_layer_bottom(0.0);
        double z_layer_top(0.0);
        for (size_t i_layer = 0; i_layer < n_layers && start_z_ind != 0; ++i_layer) {
            z_layer_bottom = re_sample.avgeSlice(i_layer).low();
            z_layer_top = i_layer ? re_sample.avgeSlice(i_layer).hig() : z_layer_bottom;

            // get R & T coefficients for current layer
            const auto* flux = dynamic_cast<const ScalarFlux*>(fluxes[i_layer]);
            ASSERT(flux);
            const complex_t R = flux->getScalarR();
            const complex_t T = flux->getScalarT();
            const complex_t kz_out = flux->getScalarKz();
            const complex_t kz_in = -kz_out;

            // Compute intensity for z's of the layer
            size_t ip1_z = start_z_ind;
            for (; ip1_z > 0; --ip1_z) {
                const size_t i_z = ip1_z - 1;
                if (i_layer + 1 != n_layers && m_z_axis->binCenter(i_z) <= z_layer_bottom)
                    break;
                const double z = m_z_axis->binCenter(i_z) - z_layer_top;
                complex_t psi;
                if ((m_flags & 3) == ZDirection_None)
                    psi = R * exp_I(kz_out * z) + T * exp_I(kz_in * z);
                else if (m_flags & ZDirection_Reflected)
                    psi = R * exp_I(kz_out * z);
                else if (m_flags & ZDirection_Transmitted)
                    psi = T * exp_I(kz_in * z);
                else
                    throw std::runtime_error("Invalid combination of ZDirection flags");
                if ((m_flags & 12) == WaveProperty_Intensity)
                    intensities[i_z] = std::norm(psi);
                else if (m_flags & WaveProperty_Modulus)
                    intensities[i_z] = std::abs(psi);
                else if (m_flags & WaveProperty_Phase)
                    intensities[i_z] = std::arg(psi);
                else
                    throw std::runtime_error("Invalid combination of WaveProperty flags");
            }
            start_z_ind = ip1_z;
        }
    }

    double intensity_factor = m_scan->intensityAt(i);
    for (double& v : intensities)
        v *= intensity_factor;

    const size_t N1 = m_z_axis->size();
    for (size_t j = 0; j < N1; ++j)
        m_cache[j * m_scan->nScan() + i] += intensities[j] * weight;

    progress().incrementDone(1);
}

//... Overridden getters:

//! Returns the number of scan points, i.e. the size of the scan's coordinate axis.
size_t DepthprobeSimulation::nElements() const
{
    const auto& scan_axis = *m_scan->coordinateAxis();
    return scan_axis.size();
}

//! Returns the number of output values: one per (scan point, z bin) pair.
size_t DepthprobeSimulation::nOutChannels() const
{
    const size_t n_z = m_z_axis->size();
    return n_z * nElements();
}

//! Packs the accumulated cache into a two-dimensional Datafield. Its axes are clones of the
//! scan's coordinate axis and of the z axis.
//! @throws std::runtime_error if a background is set; this simulation does not support one.
Datafield DepthprobeSimulation::packResult()
{
    // Reject the unsupported configuration before building any output.
    if (background())
        throw std::runtime_error("nonzero background is not supported by DepthprobeSimulation");

    // Build the result directly as the return value (no intermediate heap object to copy from).
    std::vector<const Scale*> axes{m_scan->coordinateAxis()->clone(), m_z_axis->clone()};
    return Datafield(axes, m_cache);
}
