/**
 * @file aug_lagrangian.hpp
 * @author Ryan Curtin
 *
 * Definition of AugLagrangian class, which implements the Augmented Lagrangian
 * optimization method (also called the 'method of multipliers').  This class
 * uses the L-BFGS optimizer.
 *
 * ensmallen is free software; you may redistribute it and/or modify it under
 * the terms of the 3-clause BSD license.  You should have received a copy of
 * the 3-clause BSD license along with ensmallen.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */

#ifndef ENSMALLEN_AUG_LAGRANGIAN_AUG_LAGRANGIAN_HPP
#define ENSMALLEN_AUG_LAGRANGIAN_AUG_LAGRANGIAN_HPP

#include <ensmallen_bits/lbfgs/lbfgs.hpp>

#include "aug_lagrangian_function.hpp"

namespace ens {

/**
 * The AugLagrangian class implements the Augmented Lagrangian method of
 * optimization.  In this scheme, a penalty term is added to the Lagrangian.
 * This method is also called the "method of multipliers".
 *
 * AugLagrangian can optimize constrained functions.  For more details, see the
 * documentation on function types included with this distribution or on the
 * ensmallen website.
 *
 * @tparam VecType Vector type used to store the Lagrange multipliers for the
 *     deprecated Lambda()/Sigma() interface.
 */
template<typename VecType = arma::vec> // TODO: remove for ensmallen 4.x
class AugLagrangianType
{
 public:
  /**
   * Initialize the Augmented Lagrangian with the default L-BFGS optimizer.
   *
   * @param maxIterations Maximum number of iterations of the Augmented
   *     Lagrangian algorithm.  0 indicates no maximum.
   * @param penaltyThresholdFactor When the penalty threshold is updated, set
   *     the penalty threshold to the penalty multiplied by this factor.  The
   *     default value of 0.25 is taken from Burer and Monteiro (2002).
   * @param sigmaUpdateFactor When sigma is updated, multiply sigma by this
   *     value.  The default value of 10 is taken from Burer and Monteiro
   *     (2002).
   * @param lbfgs L-BFGS optimizer object to use for the actual optimization.
   */
  AugLagrangianType(const size_t maxIterations = 1000,
                    const double penaltyThresholdFactor = 0.25,
                    const double sigmaUpdateFactor = 10.0,
                    const L_BFGS& lbfgs = L_BFGS());

  /**
   * Optimize the function.  The value '0' is used for the initial value of each
   * Lagrange multiplier.  To set the Lagrange multipliers yourself, use one of
   * the overloads of Optimize() that takes lambda and sigma.
   *
   * @tparam LagrangianFunctionType Function which can be optimized by this
   *     class.
   * @tparam MatType Type of matrix to optimize with.
   * @tparam GradType Type of matrix to use to represent function gradients.
   * @tparam CallbackTypes Types of callback functions.
   * @param function The function to optimize.
   * @param coordinates Output matrix to store the optimized coordinates in.
   * @param callbacks Callback functions.
   */
  template<typename LagrangianFunctionType,
           typename MatType,
           typename GradType,
           typename... CallbackTypes>
  typename std::enable_if<IsMatrixType<GradType>::value &&
                          IsAllNonMatrix<CallbackTypes...>::value, bool>::type
  Optimize(LagrangianFunctionType& function,
           MatType& coordinates,
           CallbackTypes&&... callbacks);

  //! Convenience overload: uses MatType as the GradType.
  template<typename LagrangianFunctionType,
           typename MatType,
           typename... CallbackTypes>
  typename std::enable_if<IsAllNonMatrix<CallbackTypes...>::value, bool>::type
  Optimize(LagrangianFunctionType& function,
           MatType& coordinates,
           CallbackTypes&&... callbacks)
  {
    return Optimize<LagrangianFunctionType, MatType, MatType,
        CallbackTypes...>(function, coordinates,
        std::forward<CallbackTypes>(callbacks)...);
  }

  /**
   * Deprecated: optimize the function, giving initial estimates for the
   * Lagrange multipliers and the penalty parameter by const reference / value.
   *
   * The initial values are copied into this object's internal (deprecated)
   * storage, and the optimization is then run on that storage via the
   * non-const lambda/sigma overload.  The internal storage is what the
   * deprecated Lambda() and Sigma() accessors expose.
   *
   * @tparam LagrangianFunctionType Function which can be optimized by this
   *      class.
   * @tparam MatType Type of matrix to optimize with.
   * @tparam InVecType Type of the vector holding the initial multipliers; must
   *      be assignable to VecType.
   * @tparam GradType Type of matrix to use to represent function gradients.
   * @tparam CallbackTypes Types of callback functions.
   * @param function The function to optimize.
   * @param coordinates Output matrix to store the optimized coordinates in.
   * @param initLambda Vector containing initial Lagrange multipliers.  Should
   *     have length equal to the number of constraints.  Not modified.
   * @param initSigma Initial penalty parameter.
   * @param callbacks Callback functions.
   */
  template<typename LagrangianFunctionType,
           typename MatType,
           typename InVecType,
           typename GradType,
           typename... CallbackTypes>
  [[deprecated("use Optimize() with non-const lambda/sigma instead")]]
  typename std::enable_if<IsMatrixType<GradType>::value, bool>::type
  Optimize(LagrangianFunctionType& function,
           MatType& coordinates,
           const InVecType& initLambda,
           const double initSigma,
           CallbackTypes&&... callbacks)
  {
    deprecatedLambda = initLambda;
    deprecatedSigma = initSigma;

    // Name the template arguments explicitly so that the caller's GradType is
    // honored: GradType cannot be deduced from the function arguments, so an
    // unqualified call would fall back to the MatType-as-GradType overload.
    return Optimize<LagrangianFunctionType, MatType, VecType, GradType,
        CallbackTypes...>(function, coordinates, this->deprecatedLambda,
        this->deprecatedSigma,
        std::forward<CallbackTypes>(callbacks)...);
  }

  /**
   * Optimize the function, giving initial estimates for the Lagrange
   * multipliers.  The vector of Lagrange multipliers will be modified to
   * contain the Lagrange multipliers of the final solution (if one is found).
   *
   * @tparam LagrangianFunctionType Function which can be optimized by this
   *      class.
   * @tparam MatType Type of matrix to optimize with.
   * @tparam InVecType Type of the vector holding the Lagrange multipliers.
   * @tparam GradType Type of matrix to use to represent function gradients.
   * @tparam CallbackTypes Types of callback functions.
   * @param function The function to optimize.
   * @param coordinates Output matrix to store the optimized coordinates in.
   * @param lambda Vector containing initial Lagrange multipliers.  Should have
   *     length equal to the number of constraints.  This will be overwritten
   *     with the Lagrange multipliers that are found during optimization.
   * @param sigma Initial penalty parameter.  This will be overwritten with the
   *     final penalty value used during optimization.
   * @param callbacks Callback functions.
   */
  template<typename LagrangianFunctionType,
           typename MatType,
           typename InVecType,
           typename GradType,
           typename... CallbackTypes>
  typename std::enable_if<IsMatrixType<GradType>::value, bool>::type
  Optimize(LagrangianFunctionType& function,
           MatType& coordinates,
           InVecType& lambda,
           double& sigma,
           CallbackTypes&&... callbacks);

  //! Deprecated convenience overload: uses MatType as the GradType.
  template<typename LagrangianFunctionType,
           typename MatType,
           typename... CallbackTypes>
  [[deprecated("use Optimize() with non-const lambda/sigma instead")]]
  bool Optimize(LagrangianFunctionType& function,
                MatType& coordinates,
                const VecType& initLambda,
                const double initSigma,
                CallbackTypes&&... callbacks)
  {
    // The target overload's template parameters are, in order:
    // <LagrangianFunctionType, MatType, InVecType, GradType, CallbackTypes...>.
    // InVecType (here VecType) must be spelled out so that MatType lands in
    // the GradType slot rather than the InVecType slot.
    return Optimize<LagrangianFunctionType, MatType, VecType, MatType,
        CallbackTypes...>(function, coordinates, initLambda, initSigma,
        std::forward<CallbackTypes>(callbacks)...);
  }

  //! Convenience overload: uses MatType as the GradType; lambda and sigma are
  //! passed by non-const reference through to the full overload.
  template<typename LagrangianFunctionType,
           typename MatType,
           typename InVecType,
           typename... CallbackTypes>
  bool Optimize(LagrangianFunctionType& function,
                MatType& coordinates,
                InVecType& lambda,
                double& sigma,
                CallbackTypes&&... callbacks)
  {
    return Optimize<LagrangianFunctionType, MatType, InVecType, MatType,
        CallbackTypes...>(function, coordinates, lambda, sigma,
        std::forward<CallbackTypes>(callbacks)...);
  }

  //! Get the L-BFGS object used for the actual optimization.
  const L_BFGS& LBFGS() const { return lbfgs; }
  //! Modify the L-BFGS object used for the actual optimization.
  L_BFGS& LBFGS() { return lbfgs; }

  //! Get the Lagrange multipliers.
  [[deprecated("use Optimize() with lambda/sigma parameters instead")]]
  const VecType& Lambda() const { return deprecatedLambda; }
  //! Modify the Lagrange multipliers (i.e. set them before optimization).
  [[deprecated("use Optimize() with lambda/sigma parameters instead")]]
  VecType& Lambda() { return deprecatedLambda; }

  //! Get the penalty parameter.
  [[deprecated("use Optimize() with lambda/sigma parameters instead")]]
  double Sigma() const { return deprecatedSigma; }
  //! Modify the penalty parameter.
  [[deprecated("use Optimize() with lambda/sigma parameters instead")]]
  double& Sigma() { return deprecatedSigma; }

  //! Get the maximum iterations.
  size_t MaxIterations() const { return maxIterations; }
  //! Modify the maximum iterations.
  size_t& MaxIterations() { return maxIterations; }

  //! Get the penalty threshold updating parameter.
  double PenaltyThresholdFactor() const { return penaltyThresholdFactor; }
  //! Modify the penalty threshold updating parameter.
  double& PenaltyThresholdFactor() { return penaltyThresholdFactor; }

  //! Get the sigma update factor.
  double SigmaUpdateFactor() const { return sigmaUpdateFactor; }
  //! Modify the sigma update factor.
  double& SigmaUpdateFactor() { return sigmaUpdateFactor; }

 private:
  //! Maximum number of iterations.
  size_t maxIterations;

  //! Parameter for updating the penalty threshold.
  double penaltyThresholdFactor;

  //! Parameter for updating sigma.
  double sigmaUpdateFactor;

  //! The L-BFGS optimizer that we will use.
  L_BFGS lbfgs;

  //! Controls early termination of the optimization process.
  bool terminate;

  // NOTE: these will be removed in ensmallen 4.x!
  //! Lagrange multipliers (backing storage for the deprecated interface).
  VecType deprecatedLambda;
  //! Penalty parameter (backing storage for the deprecated interface).
  double deprecatedSigma;

  /**
   * Internal optimization function: given an initialized AugLagrangianFunction,
   * perform the optimization itself.
   *
   * @tparam LagrangianFunctionType Function type wrapped by the
   *     AugLagrangianFunction.
   * @tparam MatType Type of matrix to optimize with.
   * @tparam InVecType Vector type used by the AugLagrangianFunction.
   * @tparam GradType Type of matrix to use to represent function gradients.
   * @tparam CallbackTypes Types of callback functions.
   * @param augfunc The initialized augmented Lagrangian function.
   * @param coordinates Output matrix to store the optimized coordinates in.
   * @param callbacks Callback functions.
   */
  template<typename LagrangianFunctionType,
           typename MatType,
           typename InVecType,
           typename GradType,
           typename... CallbackTypes>
  typename std::enable_if<IsMatrixType<GradType>::value, bool>::type
  Optimize(AugLagrangianFunction<LagrangianFunctionType, InVecType>& augfunc,
           MatType& coordinates,
           CallbackTypes&&... callbacks);

  //! Internal convenience overload: uses MatType as the GradType.
  template<typename LagrangianFunctionType,
           typename MatType,
           typename InVecType,
           typename... CallbackTypes>
  bool Optimize(
      AugLagrangianFunction<LagrangianFunctionType, InVecType>& function,
      MatType& coordinates,
      CallbackTypes&&... callbacks)
  {
    return Optimize<LagrangianFunctionType, MatType, InVecType, MatType,
        CallbackTypes...>(function, coordinates,
        std::forward<CallbackTypes>(callbacks)...);
  }
};

//! Convenience alias: AugLagrangianType with arma::vec as its VecType.
using AugLagrangian = AugLagrangianType<arma::vec>;

} // namespace ens

#include "aug_lagrangian_impl.hpp"

#endif // ENSMALLEN_AUG_LAGRANGIAN_AUG_LAGRANGIAN_HPP

