#pragma once

/////////////
// tuple.h //
/////////////

// why reinvent the wheel and implement a tuple class?
//  - ensure data is laid out in the same order the types are specified
//        see: https://github.com/EnzymeAD/Enzyme/issues/1191#issuecomment-1556239213
//  - CUDA compatibility: std::tuple has some compatibility issues when used
//        in a __device__ context (this may get better in c++20 with the improved
//        constexpr support for std::tuple). Owning the implementation lets
//        us add __host__ __device__ annotations to any part of it

#include <cstddef> // for std::size_t
#include <utility> // for std::integer_sequence

#include <enzyme/type_traits>

#define _NOEXCEPT noexcept
namespace enzyme {

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"

// Empty tag type naming a tuple position at compile time. An `Index<i>` value
// is passed to `operator[]` so that overload resolution picks the element
// stored at position `i`.
template <int i>
struct Index {};

// Storage for one tuple element of type `T`, tagged with its position `i`.
//
// The struct has a single public member and no constructors, so it remains an
// aggregate and can be brace-initialized as part of the enclosing tuple.
template <int i, typename T>
struct value_at_position {
  // Mutable access to the stored element. Declared constexpr (like the const
  // overload below) so elements can also be modified during constant
  // evaluation.
  __attribute__((always_inline))
  constexpr T & operator[](Index<i>) { return value; }

  // Read-only access to the stored element.
  __attribute__((always_inline))
  constexpr const T & operator[](Index<i>) const { return value; }

  T value;
};

// Primary template: declared only. `S` is matched against
// std::integer_sequence<int, ...> by the partial specialization below.
template <typename S, typename... T>
struct tuple_base;

// Pairs the k-th index with the k-th element type and derives from one
// value_at_position per pair. The bases are listed in index order, i.e. in
// the same order the element types were written.
template <int... position, typename... Element>
struct tuple_base<std::integer_sequence<int, position...>, Element...>
    : value_at_position<position, Element>... {
  // Every base declares its own operator[](Index<k>); gather all of them into
  // a single overload set so `tup[Index<k>{}]` resolves to the right base.
  using value_at_position<position, Element>::operator[]...;
};

// The tuple type itself: numbers the element types 0 .. sizeof...(T)-1 and
// lets tuple_base provide the storage and the indexed operator[] overloads.
// It declares no members or constructors of its own.
template <typename... T>
struct tuple
    : tuple_base<std::make_integer_sequence<int, static_cast<int>(sizeof...(T))>, T...> {};

// Deduction guide: `tuple{a, b, c}` deduces tuple<A, B, C>, where the element
// types are deduced from by-value parameters (so references and top-level
// cv-qualifiers of the arguments are dropped). A deduction guide is never
// called and generates no code, so it carries no inlining attribute.
template <typename... T>
tuple(T...) -> tuple<T...>;

template < int i, typename Tuple >
__attribute__((always_inline))
decltype(auto) get(Tuple && tup) {
  constexpr bool is_lvalue = std::is_lvalue_reference_v<Tuple>;
  constexpr bool is_const = std::is_const_v<std::remove_reference_t<Tuple>>;
  using T = remove_cvref_t< decltype(tup[Index<i>{ } ]) >;
  if constexpr ( is_lvalue &&  is_const) { return static_cast<const T&>(tup[Index<i>{} ]); }
  if constexpr ( is_lvalue && !is_const) { return static_cast<T&>(tup[Index<i>{} ]); }
  if constexpr (!is_lvalue &&  is_const) { return static_cast<const T&&>(tup[Index<i>{} ]); }
  if constexpr (!is_lvalue && !is_const) { return static_cast<T&&>(tup[Index<i>{} ]); }
}

// Overload for a const lvalue enzyme::tuple: returns exactly what the const
// `operator[]` of the element at position `i` returns. That operator is
// constexpr, so this accessor is constexpr as well and can be used in
// constant expressions.
template < int i, typename ... T>
__attribute__((always_inline))
constexpr decltype(auto) get(const tuple< T ... > & tup) {
    return tup[Index<i>{} ];
}

// Number of elements of a tuple type. The primary template is declared only,
// so asking for the size of anything that is not exactly an enzyme::tuple
// (including a cv-qualified or reference-to tuple type) is a compile error.
template <typename Tuple>
struct tuple_size;

template <typename... T>
struct tuple_size<tuple<T...>> : std::integral_constant<std::size_t, sizeof...(T)> {};

// Shorthand for tuple_size<Tuple>::value. Declared `inline` rather than
// `static`: this is a header, and an inline variable is one entity shared by
// all translation units instead of a separate internal-linkage copy in each.
template <typename Tuple>
inline constexpr std::size_t tuple_size_v = tuple_size<Tuple>::value;

template <typename... T>
__attribute__((always_inline))
constexpr auto forward_as_tuple(T&&... args) noexcept {
  return tuple<T&&...>{impl::forward<T>(args)...};
}

namespace impl {

// Primary template: declared only; the work happens in the specialization on
// std::index_sequence below.
template <typename index_seq>
struct make_tuple_from_fwd_tuple;

// Turns a tuple (typically a tuple of references) into a tuple of values.
template <std::size_t... k>
struct make_tuple_from_fwd_tuple<std::index_sequence<k...>> {
  // Reads element k... of `source` through get<> and brace-initializes a new
  // tuple from the results. The element types of the result come from the
  // tuple deduction guide, so each element is a by-value copy (or move) of
  // whatever get<> returned for it.
  template <typename Source>
  __attribute__((always_inline))
  static constexpr auto f(Source&& source) {
    return tuple{ get<k>(impl::forward<Source>(source))... };
  }
};

template <typename FWD_INDEX_SEQ, typename TUPLE_INDEX_SEQ>
struct concat_with_fwd_tuple;

template < typename Tuple >
using iseq = std::make_index_sequence<tuple_size_v< enzyme::remove_cvref_t< Tuple > > >;

template <std::size_t... fwd_indices, std::size_t... indices>
struct concat_with_fwd_tuple<std::index_sequence<fwd_indices...>, std::index_sequence<indices...>> {
  template <typename FWD_TUPLE, typename TUPLE>
  __attribute__((always_inline))
  static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) {
    return enzyme::forward_as_tuple(get<fwd_indices>(impl::forward<FWD_TUPLE>(fwd))..., get<indices>(impl::forward<TUPLE>(t))...);
  }
};

template <typename Tuple>
__attribute__((always_inline))
static constexpr auto tuple_cat(Tuple&& ret) {
  return make_tuple_from_fwd_tuple< iseq< Tuple > >::f(impl::forward< Tuple >(ret));
}

template <typename FWD_TUPLE, typename first, typename... rest>
__attribute__((always_inline))
static constexpr auto tuple_cat(FWD_TUPLE&& fwd, first&& t, rest&&... ts) {
  return tuple_cat(concat_with_fwd_tuple< iseq<FWD_TUPLE>, iseq<first> >::f(impl::forward<FWD_TUPLE>(fwd), impl::forward<first>(t)), impl::forward<rest>(ts)...);
}

}  // namespace impl

template <typename... Tuples>
__attribute__((always_inline))
constexpr auto tuple_cat(Tuples&&... tuples) {
  return impl::tuple_cat(impl::forward<Tuples>(tuples)...);
}

#pragma clang diagnostic pop

} // namespace enzyme
#undef _NOEXCEPT
