//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Sim/Scan/AlphaScan.cpp
//! @brief     Implements AlphaScan class.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Sim/Scan/AlphaScan.h"
#include "Base/Axis/MakeScale.h"
#include "Base/Axis/Scale.h"
#include "Device/Beam/IFootprint.h"
#include "Device/Coord/CoordSystem1D.h"
#include "Device/Pol/PolFilter.h"
#include "Param/Distrib/Distributions.h"
#include "Param/Distrib/ParameterSample.h"
#include "Resample/Element/SpecularElement.h"
#include <algorithm>
#include <memory>
#include <numbers>
#include <stdexcept>

using std::numbers::pi;

namespace {

//! Returns how many samples an optional resolution distribution contributes:
//! the distribution's own sample count, or 1 when no distribution is given.
size_t nResolSamples(const IDistribution1D* distrib)
{
    if (distrib == nullptr)
        return 1L;
    return distrib->nSamples();
}

//! Returns the samples of an optional resolution distribution. Without a distribution,
//! a single sample {0., 1} is returned (presumably value 0 and weight 1, assuming
//! ParameterSample's (value, weight) member order — TODO confirm).
std::vector<ParameterSample> drawDistribution(const IDistribution1D* distrib)
{
    if (distrib != nullptr)
        return distrib->distributionSamples();
    return {{0., 1}};
}

} // namespace


//! Creates a scan whose nominal inclination angles are the bin centers of alpha_axis
//! (the axis is cloned). The second IBeamScan argument (0.0) presumably is the initial
//! wavelength, i.e. "not set yet" — TODO confirm against IBeamScan.
//! @throws std::runtime_error (from checkInitialization) if the bin centers are not
//!         sorted in ascending order.
AlphaScan::AlphaScan(const Scale& alpha_axis)
    : IBeamScan(alpha_axis.clone(), 0.0)
{
    checkInitialization();
}

//! Creates a scan over an axis named "alpha_i" built by EquiScan from nbins,
//! alpha_i_min and alpha_i_max (presumably nbins equidistant angles in that range;
//! see EquiScan for the exact bin layout). Delegates to AlphaScan(const Scale&).
AlphaScan::AlphaScan(int nbins, double alpha_i_min, double alpha_i_max)
    : AlphaScan(EquiScan("alpha_i", nbins, alpha_i_min, alpha_i_max))
{
}

//! Defaulted out of line; presumably so that members held by unique_ptr are destroyed
//! where their types are complete (this file includes their headers).
AlphaScan::~AlphaScan() = default;

//! Returns a deep copy of this scan. The caller takes ownership of the returned pointer.
//! Copies intensity, footprint, alpha offset, wavelength or wavelength distribution,
//! angle distribution, beam polarization and polarization analyzer.
//! A scan whose wavelength has not been set yet (wavelength() == 0, no distribution)
//! can be cloned; the copy also has no wavelength set.
AlphaScan* AlphaScan::clone() const
{
    // Own the copy until it is fully configured, so it is not leaked if a setter throws.
    std::unique_ptr<AlphaScan> result(new AlphaScan(*m_axis));
    result->setIntensity(intensity());
    result->setFootprint(m_footprint.get());
    result->setAlphaOffset(alphaOffset());

    if (m_lambda_distrib)
        result->m_lambda_distrib.reset(m_lambda_distrib->clone());
    else if (wavelength() > 0)
        // setWavelength() rejects non-positive values, so an unset wavelength (0)
        // is left unset in the copy instead of making clone() throw.
        result->setWavelength(wavelength());

    if (m_alpha_distrib)
        result->m_alpha_distrib.reset(m_alpha_distrib->clone());

    if (m_beamPolarization)
        result->m_beamPolarization.reset(new R3(*m_beamPolarization));
    if (m_polAnalyzer)
        result->m_polAnalyzer.reset(new PolFilter(*m_polAnalyzer));

    return result.release();
}

std::vector<const INode*> AlphaScan::nodeChildren() const
{
    std::vector<const INode*> result;
    for (const INode* n : IBeamScan::nodeChildren())
        result << n;
    if (m_lambda_distrib)
        result << m_lambda_distrib.get();
    if (m_alpha_distrib)
        result << m_alpha_distrib.get();
    return result;
}

std::vector<SpecularElement> AlphaScan::generateElements() const
{
    std::vector<SpecularElement> result;
    result.reserve(nScan() * nDistributionSamples());

    if (!m_lambda_distrib && wavelength() <= 0)
        throw std::runtime_error("Specular scan has neither wavelength > 0 nor distribution");

    for (size_t i = 0; i < m_axis->size(); ++i) {
        const std::vector<ParameterSample> alphaDistrib = drawDistribution(m_alpha_distrib.get());
        const std::vector<ParameterSample> lambdaDistrib = drawDistribution(m_lambda_distrib.get());
        for (size_t j = 0; j < alphaDistrib.size(); ++j) {
            const double alpha = m_axis->binCenters()[i] + alphaDistrib[j].value + m_alpha_offset;
            for (size_t k = 0; k < lambdaDistrib.size(); ++k) {
                const double lambda = m_lambda_distrib ? lambdaDistrib[k].value : wavelength();
                const bool computable = lambda >= 0 && alpha >= 0 && alpha <= (pi / 2);
                const double weight = alphaDistrib[j].weight * lambdaDistrib[k].weight;
                const double footprint = m_footprint ? m_footprint->calculate(alpha) : 1;
                result.emplace_back(SpecularElement::FromAlphaScan(i, weight, lambda, -alpha,
                                                                   footprint, polarizerMatrix(),
                                                                   analyzerMatrix(), computable));
            }
        }
    }
    return result;
}

//! Sets a fixed, single wavelength for the scan.
//! @throws std::runtime_error if a wavelength distribution has already been set,
//!         or if lambda <= 0.
void AlphaScan::setWavelength(double lambda)
{
    if (m_lambda_distrib) {
        throw std::runtime_error("AlphaScan: wavelength already set through distribution");
    } else if (lambda <= 0) {
        throw std::runtime_error("AlphaScan: wavelength must be set to positive value");
    }
    m_lambda0 = lambda;
}

//! Sets a wavelength distribution (a clone of distr is stored). Its sample values are
//! used as absolute wavelengths when elements are generated.
//! @throws std::runtime_error if the distribution's mean is not positive, or if an
//!         explicit wavelength has already been set (wavelength() != 0).
void AlphaScan::setWavelengthDistribution(const IDistribution1D& distr)
{
    const bool meanIsPositive = !(distr.mean() <= 0);
    if (!meanIsPositive)
        throw std::runtime_error("AlphaScan: mean wavelength must be set to positive value");

    const bool explicitWavelengthSet = wavelength() != 0;
    if (explicitWavelengthSet)
        throw std::runtime_error("AlphaScan does not allow wavelength distribution "
                                 "as explicit wavelength has been set");

    m_lambda_distrib.reset(distr.clone());
}

//! Sets the inclination-angle resolution distribution, replacing any previous one.
//! A clone of distr is stored. In generateElements() each sample value is added to the
//! nominal bin-center angle. No validation of distr is performed here.
void AlphaScan::setAngleDistribution(const IDistribution1D& distr)
{
    m_alpha_distrib.reset(distr.clone());
}

//! Returns the number of (wavelength, angle) sample combinations per scan point,
//! counting an absent distribution as contributing a single sample.
size_t AlphaScan::nDistributionSamples() const
{
    const size_t nLambdaSamples = nResolSamples(m_lambda_distrib.get());
    const size_t nAlphaSamples = nResolSamples(m_alpha_distrib.get());
    return nLambdaSamples * nAlphaSamples;
}

//! Returns a new angular coordinate system for this scan's axis; the caller takes
//! ownership.
//! The nominal wavelength is the fixed wavelength or, if a wavelength distribution is
//! set, the distribution's mean. setWavelengthDistribution() only accepts a distribution
//! while wavelength() == 0, so using wavelength() there would pass 0.
CoordSystem1D* AlphaScan::scanCoordSystem() const
{
    const double lambda = m_lambda_distrib ? m_lambda_distrib->mean() : wavelength();
    return new AngularReflectometryCoords(lambda, *coordinateAxis());
}

//! Validates the scan axis: bin centers must be in ascending order.
//! @throws std::runtime_error if the bin centers are not sorted ascending.
void AlphaScan::checkInitialization()
{
    // TODO: check for inclination angle limits after switching to pointwise resolution.
    const std::vector<double> centers = m_axis->binCenters();
    if (std::is_sorted(centers.begin(), centers.end()))
        return;
    throw std::runtime_error("AlphaScan called with invalid alpha_i vector:"
                             " is not sorted in ascending order");
}
