#ifndef SOPT_IMAGING_FORWARD_BACKWARD_H
#define SOPT_IMAGING_FORWARD_BACKWARD_H

#include "sopt/config.h"
#include <limits> // for std::numeric_limits<>
#include <memory> // for std::shared_ptr<>
#include <numeric>
#include <tuple>
#include <utility>
#include "sopt/exception.h"
#include "sopt/forward_backward.h"
#include "sopt/linear_transform.h"
#include "sopt/logging.h"
#include "sopt/proximal.h"
#include "sopt/relative_variation.h"
#include "sopt/types.h"
#include "sopt/non_differentiable_func.h"
#include "sopt/differentiable_func.h"

#include <functional>
#include "sopt/gradient_utils.h"
#include <stdexcept>

#ifdef SOPT_MPI
#include "sopt/mpi/communicator.h"
#include "sopt/mpi/utilities.h"
#endif

namespace sopt::algorithm {
//! \brief Imaging front end that configures and runs a ForwardBackward solver.
//! \details The objective is assumed to be \f$f(x, y, \Phi) + g(x)\f$, where \f$f\f$ is
//! differentiable with a supplied gradient and \f$g\f$ is non-differentiable with a supplied
//! proximal operator. Throughout this class `f` and `g` in names refer to these two parts.
//! Properties are set through chainable setters; the solver itself is built and run by
//! the protected three-argument operator().
template <typename SCALAR>
class ImagingForwardBackward {
  //! Underlying algorithm
  using FB = ForwardBackward<SCALAR>;

 public:
  using value_type = typename FB::value_type;
  using Scalar = typename FB::Scalar;
  using Real = typename FB::Real;
  using t_Vector = typename FB::t_Vector;
  using t_LinearTransform = typename FB::t_LinearTransform;
  using t_Proximal = typename FB::t_Proximal;
  using t_Gradient = typename FB::t_Gradient;
  using t_l2Gradient = typename std::function<void(t_Vector &, const t_Vector &)>;
  using t_IsConverged = typename FB::t_IsConverged;
  using t_randomUpdater = typename FB::t_randomUpdater;

  //! Values indicating how the algorithm ran
  struct Diagnostic : public FB::Diagnostic {
    Diagnostic(t_uint niters = 0u, bool good = false) : FB::Diagnostic(niters, good) {}
    Diagnostic(t_uint niters, bool good, t_Vector &&residual)
      : FB::Diagnostic(niters, good, std::move(residual)) {}
  };

  //! Holds result vector as well
  struct DiagnosticAndResult : public Diagnostic {
    //! Output x
    t_Vector x;
  };

  //! \brief Sets up the imaging wrapper for ForwardBackward from a fixed target vector.
  //! \details g_function_ and f_function_ start out null so that this class does not depend
  //! on any particular implementation of them. The correct implementations should be
  //! injected (via g_function() / f_function()) by the code that instantiates this class.
  //! The measurement operator is initialised to the identity.
  // Note: Using setter injection instead of constructor injection to follow the
  // style in the rest of the class, although constructor might be more appropriate.
  //! \param[in] target: Vector of target measurements
  ImagingForwardBackward(t_Vector const &target)
      : tight_frame_(false),
        residual_tolerance_(0.),
        relative_variation_(1e-4),
        residual_convergence_(nullptr),
        objective_convergence_(nullptr),
        itermax_(std::numeric_limits<t_uint>::max()),
        regulariser_strength_(1e-8),
        step_size_(1),
        sigma_(1),
        fista_(true),
        is_converged_(),
        g_function_(nullptr),
        f_function_(nullptr),
        random_updater_(nullptr)
        {
          std::shared_ptr<t_LinearTransform> Id = std::make_shared<t_LinearTransform>(linear_transform_identity<Scalar>());
          problem_state = std::make_shared<IterationState<t_Vector>>(target, Id);
        }

  //! \brief Sets up the imaging wrapper for a problem whose data comes from a random updater.
  //! \details The updater is invoked once here to obtain the initial problem state, because
  //! target and Phi are not known ahead of time for random data sets.
  //! \param[in] updater: Callable producing the problem state
  //! \throws std::runtime_error if the updater is empty
  ImagingForwardBackward(t_randomUpdater &updater)
      : tight_frame_(false),
        residual_tolerance_(0.),
        relative_variation_(1e-4),
        residual_convergence_(nullptr),
        objective_convergence_(nullptr),
        itermax_(std::numeric_limits<t_uint>::max()),
        regulariser_strength_(1e-8),
        step_size_(1),
        sigma_(1),
        fista_(true),
        is_converged_(),
        g_function_(nullptr),
        f_function_(nullptr),
        random_updater_(updater)
        {
          if(!random_updater_)
          {
            throw std::runtime_error("Attempted to construct ImagingForwardBackward class with a null random updater. To run without random updates supply a target vector instead.");
          }
          // target and Phi are not known ahead of time for random data sets so need to be initialised
          problem_state = random_updater_();
        }

  virtual ~ImagingForwardBackward() {}
// Macro helps define properties that can be initialized as in
// auto fb = ImagingForwardBackward<float>(target).prop0(value).prop1(value);
// Each use declares a const getter, a chainable setter and the protected member NAME_.
#define SOPT_MACRO(NAME, TYPE)                             \
  TYPE const &NAME() const { return NAME##_; }             \
  ImagingForwardBackward<SCALAR> &NAME(TYPE const &(NAME)) { \
    NAME##_ = NAME;                                        \
    return *this;                                          \
  }                                                        \
                                                           \
 protected:                                                \
  TYPE NAME##_;                                            \
                                                           \
 public:

  //! Whether Ψ is a tight-frame or not
  SOPT_MACRO(tight_frame, bool);
  //! \brief Tolerance on the l2 norm of the residual used by the default residual check
  //! \details If zero or negative, the default residual criterion is disabled
  //! (treated as always satisfied).
  SOPT_MACRO(residual_tolerance, Real);
  //! \brief Tolerance on the relative variation of the objective function
  //! \details Passed to the ScalarRelativeVariation that tracks the objective. If that
  //! tolerance is zero or negative, the default objective criterion is disabled.
  SOPT_MACRO(relative_variation, Real);
  //! \brief User-supplied convergence function for the residuals
  //! \details When set, it replaces the default residual-norm test.
  SOPT_MACRO(residual_convergence, t_IsConverged);
  //! \brief User-supplied convergence function for the objective
  //! \details When set, it replaces the default relative-variation test.
  SOPT_MACRO(objective_convergence, t_IsConverged);
  //! Maximum number of iterations
  SOPT_MACRO(itermax, t_uint);
  //! Regularisation strength (γ parameter): weight applied to g(x) in the objective
  SOPT_MACRO(regulariser_strength, Real);
  //! \brief Step-size property
  //! \details NOTE(review): the protected operator() forwards the step size reported by
  //! f_function (or sigma^2 for the default gradient) to ForwardBackward, not this value.
  //! Confirm whether this property is still meant to be used.
  SOPT_MACRO(step_size, Real);
  //! σ parameter used by the default (gaussian likelihood) gradient and data-fidelity term
  SOPT_MACRO(sigma, Real);
  //! Flag for the FISTA Forward-Backward algorithm. True by default but should be false when using a learned g_function.
  SOPT_MACRO(fista, bool);
  //! A function verifying convergence
  SOPT_MACRO(is_converged, t_IsConverged);
  
  //! Measurement operator
  t_LinearTransform const &Phi() const { return problem_state->Phi(); }
  //! Sets the measurement operator on the shared problem state
  ImagingForwardBackward<SCALAR> &Phi(t_LinearTransform const &(Phi)) {
    problem_state->Phi(Phi);
    return *this;
  }

#ifdef SOPT_MPI
  //! Communicator for summing objective_function
  SOPT_MACRO(obj_comm, mpi::Communicator);
  //! Communicator for the adjoint space
  SOPT_MACRO(adjoint_space_comm, mpi::Communicator);
#endif

#undef SOPT_MACRO

  // Getter and setter for the g_function object
  // The getter of g_function can not return a const because it will be used
  // to call setters of its internal properties
  std::shared_ptr<NonDifferentiableFunc<SCALAR>> g_function() { return g_function_; }
  ImagingForwardBackward<SCALAR>& g_function( std::shared_ptr<NonDifferentiableFunc<SCALAR>> g_function) {
    g_function_ = std::move(g_function);
    return *this;
  }

  // Getter and setter for the f_function object
  // The getter of f_function can not return a const because it will be used
  // to call setters of its internal properties
  std::shared_ptr<DifferentiableFunc<SCALAR>> f_function() { return f_function_; }
  ImagingForwardBackward<SCALAR>& f_function( std::shared_ptr<DifferentiableFunc<SCALAR>> f_function) {
    f_function_ = std::move(f_function);
    return *this;
  }

  // Getter and setter for the random updater object
  t_randomUpdater &random_updater() { return random_updater_; }
  ImagingForwardBackward<SCALAR>& random_updater( t_randomUpdater &new_updater) {
    random_updater_ = new_updater;  // may change this to a move if we don't need to keep it
    return *this;
  }

  //! \brief Ψ of the g function, or the identity transform when no g function is set.
  //! \details The identity fallback is a function-local static so that the returned
  //! reference stays valid. A conditional expression mixing an lvalue with a freshly
  //! built identity would yield a temporary and hence a dangling reference.
  // NOTE(review): this assumes g_function_->Psi() returns a reference to an object that
  // outlives the call; confirm against NonDifferentiableFunc.
  t_LinearTransform const &Psi() const
  {
    if (g_function_) return g_function_->Psi();
    static const t_LinearTransform identity = linear_transform_identity<Scalar>();
    return identity;
  }

  // Default f_gradient is gradient of l2-norm
  // This gradient ignores x and is based only on residual. (x is required for other forms of gradient)
  //t_Gradient f_gradient;

  //void set_f_gradient(const t_Gradient &fgrad)
  //{
  //  f_gradient = fgrad;
  //}

  //! Vector of target measurements
  t_Vector const &target() const { return problem_state->target(); }

  //! \brief Objective value recorded by the most recent convergence check
  //! \details Zero until the algorithm has been run.
  Real objmin() const { return objmin_; }
  //! Sets the vector of target measurements
  ImagingForwardBackward<Scalar> &target(t_Vector const &target) {
    problem_state->target(target);
    return *this;
  }

  //! \brief Calls Forward Backward
  //! \details Not const-qualified: it forwards to the non-const overloads below.
  //! \param[out] out: Output vector x
  Diagnostic operator()(t_Vector &out) {
    return operator()(out, ForwardBackward<SCALAR>::initial_guess(target(), Phi()));
  }
  //! \brief Calls Forward Backward
  //! \param[out] out: Output vector x
  //! \param[in] guess: initial guess
  Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess)  {
    return operator()(out, std::get<0>(guess), std::get<1>(guess));
  }
  //! \brief Calls Forward Backward
  //! \param[out] out: Output vector x
  //! \param[in] guess: initial guess
  Diagnostic operator()(t_Vector &out,
                        std::tuple<t_Vector const &, t_Vector const &> const &guess)  {
    return operator()(out, std::get<0>(guess), std::get<1>(guess));
  }
  //! \brief Calls Forward Backward
  //! \param[in] guess: initial guess
  DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess)  {
    return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
  }
  //! \brief Calls Forward Backward
  //! \param[in] guess: initial guess
  DiagnosticAndResult operator()(
      std::tuple<t_Vector const &, t_Vector const &> const &guess)  {
    DiagnosticAndResult result;
    static_cast<Diagnostic &>(result) = operator()(result.x, guess);
    return result;
  }
  //! \brief Calls Forward Backward starting from the default initial guess
  DiagnosticAndResult operator()()  {
    DiagnosticAndResult result;
    static_cast<Diagnostic &>(result) = operator()(
        result.x, ForwardBackward<SCALAR>::initial_guess(target(), Phi()));
    return result;
  }
  //! Makes it simple to chain different calls to FB
  DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart)  {
    DiagnosticAndResult result = warmstart;
    static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
    return result;
  }

  //! Set Φ and Φ^† using arguments that sopt::linear_transform understands
  template <typename... ARGS>
  typename std::enable_if<sizeof...(ARGS) >= 1, ImagingForwardBackward &>::type Phi(
      ARGS &&... args) {
    problem_state->Phi(linear_transform(std::forward<ARGS>(args)...));
    return *this;
  }

  //! Helper function to set-up default residual convergence function
  ImagingForwardBackward<Scalar> &residual_convergence(Real const &tolerance) {
    return residual_convergence(nullptr).residual_tolerance(tolerance);
  }
  //! Helper function to set-up default objective convergence function
  ImagingForwardBackward<Scalar> &objective_convergence(Real const &tolerance) {
    return objective_convergence(nullptr).relative_variation(tolerance);
  }
  //! Convergence function that takes only the output as argument
  ImagingForwardBackward<Scalar> &is_converged(std::function<bool(t_Vector const &x)> const &func) {
    return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
  }

 protected:

  // Store a pointer of the abstract base classes DifferentiableFunc & NonDifferentiableFunc type for f and g
  // These should point to an instance of a derived class (e.g. L1GProximal) once set up
  std::shared_ptr<NonDifferentiableFunc<SCALAR>> g_function_;
  std::shared_ptr<DifferentiableFunc<SCALAR>> f_function_;
  t_randomUpdater random_updater_;
  //! Problem state represents Phi and y s.t. the problem to solve is y = Phi x
  std::shared_ptr<IterationState<t_Vector>> problem_state;

  //! Minimum of objective function; mutable so const convergence helpers can record it
  mutable Real objmin_ = 0;

  //! \brief Calls Forward Backward
  //! \param[out] out: Output vector x
  //! \param[in] guess: initial guess
  //! \param[in] res: initial residuals
  Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) ;

  //! Helper function to simplify checking convergence
  bool residual_convergence(t_Vector const &x, t_Vector const &residual) const;

  //! Helper function to simplify checking convergence
  bool objective_convergence(ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
                             t_Vector const &residual) const;
#ifdef SOPT_MPI
  //! Helper function to simplify checking convergence
  bool objective_convergence(mpi::Communicator const &obj_comm,
                             ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
                             t_Vector const &residual) const;
#endif

  //! Helper function to simplify checking convergence
  bool is_converged(ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
                    t_Vector const &residual) const;
};

//! \brief Builds a ForwardBackward solver from the current settings and runs it.
//! \param[out] out: Output vector x
//! \param[in] guess: initial guess
//! \param[in] res: initial residuals
//! \throws std::runtime_error if no g function has been set
// NOTE(review): the step size handed to the solver comes from f_function (or sigma^2 for
// the default gradient); the step_size() property of this class is not consulted here.
template <typename SCALAR>
typename ImagingForwardBackward<SCALAR>::Diagnostic ImagingForwardBackward<SCALAR>::operator()(
    t_Vector &out, t_Vector const &guess, t_Vector const &res) {
  // The proximal step comes from g, so the algorithm cannot run without it.
  if (g_function_ == nullptr) {
    throw std::runtime_error("Non-differentiable function `g` has not been set. You must set it with `g_function()` before calling the algorithm.");
  }
  g_function_->log_message();
  auto const g_proximal = g_function_->proximal_operator();

  // Prefer the gradient and step size supplied by f, when f is available.
  t_Gradient f_gradient;
  Real gradient_step_size;
  if (f_function_ != nullptr) {
    f_gradient = f_function_->gradient();
    gradient_step_size = f_function_->get_step_size();
  }

  // Otherwise (or if f supplied an empty gradient) fall back to the built-in one.
  if (!f_gradient) {
    SOPT_MEDIUM_LOG("Gradient function has not been set; using default (gaussian likelihood) gradient. (To set a custom gradient set_gradient() must be called before the algorithm is run.)");
    f_gradient = [this](t_Vector &grad, t_Vector const & /*x*/, t_Vector const &resid,
                        t_LinearTransform const &measurement_op) {
      grad = measurement_op.adjoint() * (resid / (this->sigma() * this->sigma()));
    };
    gradient_step_size = sigma() * sigma();
  }

  // Tracks the relative variation of the objective between iterations.
  ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
                                          "Objective function");
  // Combined convergence test; also records the last objective value in objmin_.
  auto const convergence = [this, &scalvar](t_Vector const &x, t_Vector const &residual) {
    bool const converged = this->is_converged(scalvar, x, residual);
    this->objmin_ = std::real(scalvar.previous());
    return converged;
  };

  auto fb = FB(f_gradient, g_proximal, target())
                .itermax(itermax())
                .step_size(gradient_step_size)
                .regulariser_strength(regulariser_strength())
                .fista(fista())
                .Phi(Phi())
                .is_converged(convergence)
                .random_updater(random_updater_)
                .set_problem_state(problem_state);

  // Copy the solver's diagnostic into the base-class part of our own Diagnostic.
  Diagnostic result;
  static_cast<typename FB::Diagnostic &>(result) = fb(out, std::tie(guess, res));
  return result;
}

//! \brief Residual convergence check.
//! \details A user-supplied residual_convergence function takes precedence. Otherwise the
//! l2 norm of the residual is compared against residual_tolerance(); a tolerance that is
//! zero or negative disables the check (it then reports convergence).
template <typename SCALAR>
bool ImagingForwardBackward<SCALAR>::residual_convergence(t_Vector const &x,
                                                          t_Vector const &residual) const {
  auto const &user_check = residual_convergence();
  if (user_check) {
    return user_check(x, residual);
  }

  auto const &tolerance = residual_tolerance();
  if (tolerance <= 0e0) {
    return true;
  }

  auto const norm = sopt::l2_norm(residual);
  SOPT_LOW_LOG("    - [FB] Residuals: {} <? {}", norm, tolerance);
  return norm < tolerance;
}

//! \brief Objective convergence check (serial version).
//! \details A user-supplied objective_convergence function takes precedence. Otherwise the
//! objective is evaluated as regulariser_strength * g(x) plus the data term and fed to
//! scalvar. The data term is f(x, target, Phi) when an f function is set, and
//! ||residual||^2 / (2 sigma^2) otherwise.
template <typename SCALAR>
bool ImagingForwardBackward<SCALAR>::objective_convergence(ScalarRelativeVariation<Scalar> &scalvar,
                                                           t_Vector const &x,
                                                           t_Vector const &residual) const {
  auto const &user_check = objective_convergence();
  if (user_check) {
    return user_check(x, residual);
  }
  if (scalvar.relative_tolerance() <= 0e0) {
    return true;
  }

  // g is only evaluated when it actually contributes to the objective.
  auto const regulariser_term =
      (regulariser_strength() > 0) ? g_function_->function(x) * regulariser_strength() : 0;
  auto const data_term =
      (f_function_) ? f_function_->function(x, target(), Phi())
                    : std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma());
  return scalvar(regulariser_term + data_term);
}

#ifdef SOPT_MPI
//! \brief Objective convergence check (MPI version).
//! \details A user-supplied objective_convergence function takes precedence. Otherwise the
//! local objective is evaluated as regulariser_strength * g(x) plus the data term, summed
//! over obj_comm and fed to scalvar. The data term matches the serial overload:
//! f(x, target, Phi) when an f function is set, and ||residual||^2 / (2 sigma^2) otherwise,
//! so that a custom f is not silently ignored in MPI builds.
template <typename SCALAR>
bool ImagingForwardBackward<SCALAR>::objective_convergence(mpi::Communicator const &obj_comm,
                                                           ScalarRelativeVariation<Scalar> &scalvar,
                                                           t_Vector const &x,
                                                           t_Vector const &residual) const {
  if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
  if (scalvar.relative_tolerance() <= 0e0) return true;
  // g is only evaluated when it actually contributes to the objective.
  auto const regulariser_term =
      (regulariser_strength() > 0) ? g_function_->function(x) * regulariser_strength() : 0;
  auto const data_term =
      (f_function_) ? f_function_->function(x, target(), Phi())
                    : std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma());
  auto const current = obj_comm.all_sum_all<t_real>(regulariser_term + data_term);
  return scalvar(current);
}
#endif

//! \brief Combined convergence check used by the solver.
//! \details Convergence requires the optional user-supplied is_converged function, the
//! residual check and the objective check to all pass. Every check is evaluated on each
//! call, before the results are combined.
template <typename SCALAR>
bool ImagingForwardBackward<SCALAR>::is_converged(ScalarRelativeVariation<Scalar> &scalvar,
                                                  t_Vector const &x,
                                                  t_Vector const &residual) const {
  // An unset user function counts as satisfied.
  auto const &user_check = is_converged();
  bool const user_ok = user_check ? user_check(x, residual) : true;

  bool const residual_ok = residual_convergence(x, residual);
#ifdef SOPT_MPI
  bool const objective_ok = objective_convergence(obj_comm(), scalvar, x, residual);
#else
  bool const objective_ok = objective_convergence(scalvar, x, residual);
#endif
  // beware of short-circuiting!
  // All three checks were evaluated above on purpose: it is better to run each
  // convergence function every time, especially with mpi.
  return user_ok && residual_ok && objective_ok;
}
} // namespace sopt::algorithm
#endif
