/*******************************************************
 * Copyright (c) 2014, ArrayFire
 * All rights reserved.
 *
 * This file is distributed under 3-clause BSD license.
 * The complete license agreement can be obtained at:
 * http://arrayfire.com/licenses/BSD-3-Clause
 ********************************************************/

#include <complex>
#include <functional>
#include <af/dim4.hpp>
#include <af/defines.h>
#include <ArrayInfo.hpp>
#include <Array.hpp>
#include <ireduce.hpp>
#include <platform.hpp>
#include <queue.hpp>
#include <kernel/ireduce.hpp>

using af::dim4;

namespace cpu
{

// Type-erased signature of the callables that ireduce() keeps in its
// per-rank dispatch table (the kernel::ireduce_dim<op, T, N>() objects).
//
// Going by the call made in ireduce(), the arguments are, in order:
//   Array<T>     - the `out` array
//   Array<uint>  - the `loc` array
//   dim_t        - passed as 0 (presumably a starting offset into the outputs)
//   Array<T>     - the `in` array
//   dim_t        - passed as 0 (presumably a starting offset into the input)
//   int          - the dimension being reduced
//
// `op` is not used by the signature itself; it only keeps the alias
// parameterised the same way as the kernels it stores.
template<af_op_t op, typename T>
using ireduce_dim_func = std::function<void(Array<T>, Array<uint>, const dim_t,
                                            const Array<T>, const dim_t, const int)>;

/**
 * Enqueues an indexed reduction of `in` along dimension `dim`.
 *
 * @tparam op   reduction operator (this file instantiates af_min_t and
 *              af_max_t)
 * @tparam T    element type of `in` and `out`
 * @param  out  array handed to the kernel as its value output
 * @param  loc  array handed to the kernel as its location output
 * @param  in   array to reduce
 * @param  dim  dimension to reduce along
 *
 * All three arrays are eval()'d first, then the kernel matching the rank of
 * `in` is pushed onto the queue returned by getQueue(). This function does
 * not sync the queue itself.
 */
template<af_op_t op, typename T>
void ireduce(Array<T> &out, Array<uint> &loc, const Array<T> &in, const int dim)
{
    out.eval();
    loc.eval();
    in.eval();

    // One kernel per supported rank; entry N-1 handles an N-dimensional
    // input. Built once per (op, T) instantiation.
    static const ireduce_dim_func<op, T> ireduce_funcs[] = {
        kernel::ireduce_dim<op, T, 1>(),
        kernel::ireduce_dim<op, T, 2>(),
        kernel::ireduce_dim<op, T, 3>(),
        kernel::ireduce_dim<op, T, 4>()
    };

    // NOTE(review): assumes in.ndims() is in [1, 4]; anything else would
    // index outside the table - confirm callers guarantee this.
    // The two zeros are the dim_t arguments of ireduce_dim_func.
    getQueue().enqueue(ireduce_funcs[in.ndims() - 1], out, loc, 0, in, 0, dim);
}

template<af_op_t op, typename T>
T ireduce_all(unsigned *loc, const Array<T> &in)
{
    in.eval();
    getQueue().sync();

    af::dim4 dims = in.dims();
    af::dim4 strides = in.strides();
    const T *inPtr = in.get();

    kernel::MinMaxOp<op, T> Op(inPtr[0], 0);

    for(dim_t l = 0; l < dims[3]; l++) {
        dim_t off3 = l * strides[3];

        for(dim_t k = 0; k < dims[2]; k++) {
            dim_t off2 = k * strides[2];

            for(dim_t j = 0; j < dims[1]; j++) {
                dim_t off1 = j * strides[1];

                for(dim_t i = 0; i < dims[0]; i++) {
                    dim_t idx = i + off1 + off2 + off3;
                    Op(inPtr[idx], idx);
                }
            }
        }
    }

    *loc = Op.m_idx;
    return Op.m_val;
}

// Emits the explicit instantiations of ireduce() and ireduce_all() for one
// (reduction operator, element type) pair. The last line deliberately has
// no line continuation so the macro body ends with the definition.
#define INSTANTIATE(ROp, T)                                             \
    template void ireduce<ROp, T>(Array<T> &out, Array<uint> &loc,      \
                                  const Array<T> &in, const int dim);   \
    template T ireduce_all<ROp, T>(unsigned *loc, const Array<T> &in);

// Element types for which the indexed reductions are built. Both operators
// are instantiated over the same list, so it is written once and applied
// per operator.
#define INSTANTIATE_FOR_ALL_TYPES(ROp)  \
    INSTANTIATE(ROp, float  )           \
    INSTANTIATE(ROp, double )           \
    INSTANTIATE(ROp, cfloat )           \
    INSTANTIATE(ROp, cdouble)           \
    INSTANTIATE(ROp, int    )           \
    INSTANTIATE(ROp, uint   )           \
    INSTANTIATE(ROp, intl   )           \
    INSTANTIATE(ROp, uintl  )           \
    INSTANTIATE(ROp, char   )           \
    INSTANTIATE(ROp, uchar  )           \
    INSTANTIATE(ROp, short  )           \
    INSTANTIATE(ROp, ushort )

//min
INSTANTIATE_FOR_ALL_TYPES(af_min_t)

//max
INSTANTIATE_FOR_ALL_TYPES(af_max_t)

#undef INSTANTIATE_FOR_ALL_TYPES

}
