#ifndef RUMBA_MATRIX_HPP
#define RUMBA_MATRIX_HPP

/**
  * \file matrix.hpp
  * \author Donovan Rebbechi, Ben Bly, Giorgio Grasso
  * Copyright Ben Bly
  * This file is released under the artistic license. See the file
  * COPYING for details.
  */

#include <rumba/point.hpp>
#include <rumba/exception.h>

#include <algorithm>
#include <cmath>
#include <sstream>
#include <utility>
#include <vector>



// Kronecker delta: expands to 1 when x == y and to 0 otherwise. Both
// arguments are parenthesised and each appears exactly once in the expansion.
#define DELTA_IJ(x,y) ( ((x)==(y)) ? 1 : 0 )  


namespace RUMBA
{

// Forward declaration only; the class is not defined in this file. It is the
// generator argument named in the Matrix instantiation documented below.
class manifold_generator;

/**
  * The matrix class is a generic template class. Currently, there is only
  * one instance used in the RUMBA software: 
  * \code 
  * Matrix<Manifold<double>,Manifold<double>::iterator,manifold_generator> 
  * \endcode
  * 
  * The purpose of the Matrix class is to make it easier to write Matrix
  * operations on manifolds, by acting as an adaptor for a Manifold<double>
  * object. Since there is an underlying Manifold<double> object, the 
  * 4d Manifold structure can be preserved by various operations. So 
  * for example, SVD components can have a 4d structure similar to the
  * original data set. This is preferable to building matrix operations into
  * the manifold class, because the authors wish to adhere to the guideline 
  * that interfaces should be "minimal and complete" (Meyers), and having
  * a class that is designed to do many different things is clearly a breach
  * of this guideline.
  *
  * Since the Manifold class uses a reference counted store, ownership 
  * issues are simplified, in particular, when the constructor "copies"
  * a Manifold object, the reference counted data is not copied. This
  * has several consequences: copying matrices is efficient, the well-known
  * performance issues associated with temps created by expressions such
  * as
  * \code
  * A = B + C; // this creates a temporary
  * \endcode
  * do not arise.
  * Transposing matrices is extremely efficient, because transpose is 
  * implemented in a way that merely involves relabelling dimensions.
  * This is important, because transpose is a common operation.
  */

template<class manifold, class iterator, class generator >
class Matrix
{
public:
	//! The manifold object this matrix adapts (copied in by the constructors).
	manifold M;	

	/**
	 * Default constructor: zero rows, zero columns, not transposed.
	 * The strides are initialised too, so rowSkip() and colSkip() never
	 * return indeterminate values on a default-constructed matrix.
	 */
	Matrix()
		:M(),Transpose(false),Rows(0),Cols(0),RowSkip(0),ColSkip(0),Begin(0)
	{
		setSkips();
	}

	/**
	 * Constructor. This exposes some internals of the class, and should be
	 * viewed as an implementors constructor. In creating an instance of
	 * the template, implementors should also provide some convenience
	 * functions that invoke this constructor.
	 *
	 * \param M  the manifold to adapt
	 * \param it iterator to the first element of the matrix data
	 * \param r  number of rows as reported by rows()
	 * \param c  number of columns as reported by cols()
	 * \param t  true if the matrix is a transposed view of the stored data;
	 *           the stored extents are then c by r
	 */
	Matrix(const manifold& M, iterator it, int r, int c,bool t = false)
		:M(M),Transpose(t),Rows(t?c:r),Cols(t?r:c),RowSkip(0),ColSkip(0),Begin(it)
	{
		setSkips();
	}

	~Matrix()
	{
	}

	/**
	  * Returns the start of the matrix
	  */
	iterator begin() const
	{
		return Begin;
	}

	/**
	  * Returns the end of the matrix (begin() advanced by the element count)
	  */
	iterator end() const
	{
		return Begin + Rows * Cols;
	}

	/**
	  * Returns the number of rows in the matrix, taking the transpose flag
	  * into account.
	  */
	int rows() const
	{ 
		return Transpose ? Cols : Rows;
	}

	/**
	  * Returns the number of columns in the matrix, taking the transpose
	  * flag into account.
	  */
	int cols() const
	{
		return Transpose ? Rows : Cols;
	}

	/**
	  * Returns the transpose of the matrix. Only the transpose flag and the
	  * strides of the returned copy change.
	  */
	Matrix transpose() const
	{
		Matrix flipped = *this;
		flipped.Transpose = !Transpose;
		flipped.setSkips();
		return flipped;
	}


	/**
	  * For implementors. Returns the increment in an iterator necessary to
	  * jump by one row.
	  */
	int rowSkip() const
	{
		return RowSkip;
	}
	/**
	  * For implementors. Returns the increment in an iterator necessary to
	  * jump by one column.
	  */
	int colSkip() const
	{
		return ColSkip;
	}

	/**
	  * Returns a reference to element (i,j). When compiled with DEBUG
	  * defined, throws RUMBA::Exception if (i,j) lies outside the matrix.
	  */
	double& element(int i, int j)
	{
#ifdef DEBUG
		checkIndex(i,j);
#endif
		iterator tmp=Begin;
		tmp += (RowSkip *i + ColSkip * j);
		return *tmp;
	}

	/**
	  * Element access for constant matrices; returns element (i,j) by value.
	  * Performs the same DEBUG-only range check as the non-const overload.
	  */
	double element(int i, int j) const
	{
#ifdef DEBUG
		checkIndex(i,j);
#endif
		iterator tmp = Begin;
		tmp += (RowSkip*i + ColSkip*j);
		return *tmp;
	}

	/**
	  * Returns a submatrix, copied element by element into a matrix obtained
	  * from the generator.
	  * \fn Matrix  subMatrix  (
	  	unsigned int rowstart, 
	    unsigned int rowext, 
		unsigned int colstart, 
		unsigned int colext ) const
	  * \param rowstart the first row of the sub-matrix
	  * \param rowext the number of rows in the sub-matrix
	  * \param colstart the first column in the sub-matrix
	  * \param colext the number of columns in the sub-matrix
	  * Throws RUMBA::Exception if the requested block does not fit inside
	  * the matrix.
	  */
	Matrix subMatrix (unsigned int rowstart, unsigned int rowext, unsigned int colstart, unsigned int colext ) const
	{
		const unsigned int nrows = static_cast<unsigned int>(rows());
		const unsigned int ncols = static_cast<unsigned int>(cols());
		// Written with subtraction so that start+extent cannot wrap around
		// and slip past the check.
		if (rowstart > nrows || rowext > nrows - rowstart
			|| colstart > ncols || colext > ncols - colstart)
			throw RUMBA::Exception("Matrix dimensions out of range");
		generator G;
		Matrix res( G ( rowext, colext ) );
		for ( unsigned int i =0; i < rowext; ++i )
			for ( unsigned int j = 0; j < colext; ++j )
				res.element(i,j) = element(rowstart + i, colstart + j);
		return res;
	}

	/**
	  * \fn void put(unsigned int r, unsigned int c, const Matrix& M)
	  *
	  * An inverse, in some sense, of sub-matrix. This function facilitates
	  * partial assignment, by over-writing a sub-matrix with the Matrix 
	  * argument.
	  *
	  * \param r the first row of the matrix that should be over-written
	  * \param c the first column of the matrix that should be over-written
	  * \param M the matrix that overwrites a submatrix of the object on
	  * which the method is invoked (note: shadows the data member M inside
	  * this function)
	  * Throws RUMBA::Exception if M, placed at (r,c), does not fit.
	  */
	void put (unsigned int r, unsigned int c, const Matrix& M)
	{
		const unsigned int nrows = static_cast<unsigned int>(rows());
		const unsigned int ncols = static_cast<unsigned int>(cols());
		const unsigned int srcRows = static_cast<unsigned int>(M.rows());
		const unsigned int srcCols = static_cast<unsigned int>(M.cols());
		// Single, wrap-around-safe range check.
		if (r > nrows || srcRows > nrows - r
			|| c > ncols || srcCols > ncols - c)
			throw RUMBA::Exception("Matrix dimensions out of range");
		for ( unsigned int i = 0; i < srcRows; ++i )
			for ( unsigned int j = 0; j < srcCols; ++j )
				element(r+i,c+j) = M.element(i,j);
	}



	//! True when the matrix is a transposed view of the stored data.
	//! Assigning to this member directly does not refresh the strides;
	//! use transpose() to obtain a consistent transposed matrix.
	bool Transpose;

private:
	//! Recomputes RowSkip/ColSkip from Rows and the transpose flag.
	void setSkips()
	{
		if (!Transpose)
		{
			RowSkip = 1;
			ColSkip = Rows;
		}
		else
		{
			RowSkip = Rows;
			ColSkip = 1;
		}
	}

	//! Throws RUMBA::Exception unless 0 <= i < rows() and 0 <= j < cols().
	//! Only called when DEBUG is defined.
	void checkIndex(int i, int j) const
	{
		if ( i < 0 || j < 0 || i >= rows() || j >= cols() )
			throw RUMBA::Exception ("matrix extents out of range");
	}

	int Rows,Cols;        // stored extents (before applying Transpose)
	int RowSkip, ColSkip; // iterator increments for one row / one column
	iterator Begin; // random access: operator ++,--,+=,-=
};



/**
  * Multiply two matrices. This function is not used directly, it's
  * called from operator*()
  *
  * \param left  left operand
  * \param right right operand; right.rows() must equal left.cols()
  * \return a matrix obtained from generator(left,right) whose element (i,j)
  * is product_ij(left,right,i,j)
  *
  * Throws RUMBA::Exception if the inner dimensions disagree. The check is
  * made before the generator is asked for the result matrix, so no result
  * storage is requested for operands that cannot be multiplied.
  */
template<class manifold,class iterator, class generator>
Matrix<manifold,iterator,generator>
multiply 
( 
 	const Matrix<manifold,iterator,generator>& left, 
	const Matrix<manifold,iterator,generator>& right 
)
{
	if ( left.cols() != right.rows() )
	{
		std::ostringstream s; 
		s << "Incompatible matrix dimensions: (" << left.rows() << ","<< left.cols() << ") and ("
			<< right.rows() << "," << right.cols() << ")" << '\n';
		throw RUMBA::Exception(s.str());
	}

	generator g;
	Matrix<manifold,iterator,generator> result = g(left,right);

	const int nrows = result.rows();
	const int ncols = result.cols();
	for ( int i = 0; i < nrows; ++i )
	{
		for ( int j = 0; j < ncols; ++j )
		{
			result.element(i,j) = product_ij ( left,right,i,j );
		}
	}
	return result;
}

/**
  * Compute the (i,j)th element of the product matrix left*right, i.e. the
  * sum over k of left(i,k) * right(k,j) for k in [0, left.cols()).
  *
  * The caller is responsible for making sure that right has at least
  * left.cols() rows; no dimension check is made here. Returns 0 when
  * left has no columns (empty sum).
  */
template <class manifold,class iterator, class generator>
inline double product_ij 
( 
 const Matrix<manifold,iterator,generator> & left, 
 const Matrix<manifold,iterator,generator> & right,
 int i, int j )
{
	double sum=0;
	const int terms = left.cols();
	if ( terms <= 0 )
		return sum; // nothing to add up, and nothing safe to dereference

	iterator left_pos=left.begin();
	iterator right_pos=right.begin();

	left_pos += i * left.rowSkip();    // first element of row i of left
	right_pos += j * right.colSkip();  // first element of column j of right

	// The iterators are advanced only when another term remains, so they
	// never move beyond the last element that is actually read.
	// (Incrementing a pointer more than 1 past the end of an array is
	// illegal, *even if you don't dereference that pointer*.)
	for ( int k = 0; ; )
	{
		sum += (*left_pos) * (*right_pos);
		if ( ++k == terms )
			break;

		left_pos += left.colSkip();
		right_pos += right.rowSkip();
	}

	return sum;
}


/**
  * Performs LU decompositions. An \em LU-decomposition of M is a
  * factorisation M = LU with L lower-triangular and U upper-triangular.
  * It is the workhorse behind equation solving, matrix inversion and
  * determinant computation. This class additionally fixes the main
  * diagonal of L to consist entirely of 1s.
  *
  * The representation is compact: because the diagonal of L is known,
  * the entries of L strictly below the diagonal and the entries of U on or
  * above it are stored together in the matrix that was handed in, so the
  * decomposition is done in-place.
  */
template<class Matrix>
class LU_Functor
{
public:
	/**
	  * Build the LU functor from a square matrix. This overwrites the
	  * matrix with the combined LU representation described above.
	  * Throws RUMBA::Exception if the matrix is not square.
	  */
	LU_Functor(Matrix& source) 
	: M(source), Parity(1), Permutation((size_t)source.cols())
	{
		const bool square = (source.rows() == source.cols());
		if (!square)
			throw RUMBA::Exception("not a Square Matrix in LU");
		compute();	
	}

	/**
	  * Returns the determinant of M: the product of the diagonal of the
	  * decomposed matrix, multiplied by the stored parity.
	  */
	double determinant()
	{
		double product = Parity;
		int d = 0;
		while (d < M.cols())
		{
			product *= M.element(d,d);
			++d;
		}
		return product;
	}

	/**
	  * Recompute with a different matrix
	  */
	void operator() (Matrix& replacement)
	{
		M = replacement;
		Permutation.resize(M.rows());
		compute();
	}

	/**
	  * Solve MX = v. The input is the solution vector, which is over-written
	  * by the result vector.
	  */
	void solve(std::vector<double>&);

private:
	void compute();                          // performs the decomposition on M
	void getScaling(std::vector<double>&);   // per-row scale factors for pivoting
	Matrix M;                                // the matrix being decomposed
	short Parity;                            // sign factor used by determinant()
	std::vector<int> Permutation;            // row permutation applied by solve()

};

/**
  * Computes the inverse of a square matrix, destroying the argument in
  * the process: M is handed to LU_Functor, which decomposes it in place,
  * and the inverse is returned in a separate matrix obtained from the
  * generator. Use invert() to leave the argument untouched.
  *
  * The inverse is built one column at a time by solving M x = e_j for each
  * unit vector e_j.
  *
  * Throws RUMBA::Exception if M is not square or if the determinant of the
  * decomposition is zero; exceptions raised while constructing the
  * LU_Functor propagate unchanged.
  */
template<class manifold, class iterator,class generator>
Matrix<manifold,iterator,generator>
destructive_invert (Matrix<manifold,iterator,generator> & M) 
{
	// Reject non-square input before requesting storage for the result.
	if (M.rows() != M.cols())
		throw RUMBA::Exception("Not a square matrix");

	generator G;
	Matrix<manifold,iterator,generator> N = G(M);

	RUMBA::LU_Functor<Matrix<manifold,iterator,generator> > x(M);
	if (!x.determinant())
		throw RUMBA::Exception("Singular matrix exception");

	const int n = M.rows();
	std::vector<double> v ((size_t)n);
	for ( int j=0; j<n; ++j)
	{
		// right-hand side: j-th unit vector
		for(int k =0; k<n;++k) 
			v[k]=0;
		v[j]=1;
		x.solve(v); // v now holds column j of the inverse
		for (int i =0; i<n; ++i)
			N.element(i,j) = v[i];
	}
	return N;

}

/**
  * Returns the inverse of a square invertible matrix. Throws an 
  * exception if the matrix is not invertible or if the matrix is
  * not square. The argument is left unmodified: a working copy is made and
  * handed to destructive_invert().
  *
  * The working copy is filled through element(i,j) rather than by copying
  * the underlying storage linearly, so the result is correct whether or
  * not M and the generated matrix share the same storage layout (for
  * instance when M is a transposed view).
  * Assumes generator(M) yields a matrix with the same rows() and cols() as
  * M, as the arithmetic operators in this file also do.
  */
template<class manifold, class iterator,class generator>
Matrix<manifold,iterator,generator>
invert (const Matrix<manifold,iterator,generator> & M) 
{
	generator G;
	Matrix<manifold,iterator,generator> N = G(M);

	const int nrows = M.rows();
	const int ncols = M.cols();
	for ( int i = 0; i < nrows; ++i )
		for ( int j = 0; j < ncols; ++j )
			N.element(i,j) = M.element(i,j);

	return destructive_invert(N);
}


/**
  * Solves the linear system for the right-hand side v using the stored
  * decomposition, writing the solution back into the first M.cols()
  * entries of v. Throws RUMBA::Exception if v has fewer than M.cols()
  * entries.
  */
template<class Matrix>
void LU_Functor<Matrix>::solve(std::vector<double>& v)
{
	if ((int)v.size() < M.cols())
		throw RUMBA::Exception("Vector too small");

	// Work in a scratch vector; start from the right-hand side with the
	// recorded row permutation applied.
	std::vector<double> scratch ( (size_t) M.cols() );
	for ( int row = 0; row < M.cols(); ++row )
		scratch[row] = v[Permutation[row]];

	// Forward substitution (numerical recipes, eqn 2.3.6). The diagonal of
	// L is 1, so no division is needed.
	for ( int row = 0; row < M.cols(); ++row )
	{
		double acc = scratch[row];
		for ( int col = 0; col < row; ++col )
			acc -= M.element(row,col) * scratch[col];
		scratch[row] = acc;
	}

	// Back substitution (numerical recipes, eqn 2.3.7).
	for ( int row = M.cols() - 1; row >= 0; --row )
	{
		double acc = scratch[row];
		for ( int col = row + 1; col < M.cols(); ++col )
			acc -= M.element(row,col) * scratch[col];
		scratch[row] = acc / M.element(row,row);
	}

	// Hand the solution back in natural order.
	for ( int row = 0; row < M.cols(); ++row )
		v[row] = scratch[row];
}


/**
  * Performs the in-place LU decomposition of M with implicit (scaled)
  * partial pivoting, recording the row permutation in Permutation and its
  * sign in Parity.
  *
  * Each call starts from a clean state (identity permutation, parity +1),
  * so re-running the functor on a second matrix does not inherit the sign
  * of the previous decomposition.
  *
  * If a column has no non-zero pivot candidate, the diagonal entry is left
  * at zero and the entries below it are left untouched (they are all zero
  * as well), so determinant() reports 0 instead of the factorisation being
  * filled with inf/NaN.
  */
template<class Matrix>
void LU_Functor<Matrix>::compute()
{
	// NOTE(review): this throws a string literal rather than
	// RUMBA::Exception like the rest of the file; kept as-is so that the
	// exception type seen by callers does not change.
	if (M.rows() != M.cols())
		throw "notASquareMatrixException(\"LU\")";

	Parity = 1;

	// v[i] is the scale factor of the row currently stored at index i.
	std::vector<double> v((size_t)M.rows());
	getScaling(v);
	for (size_t i = 0; i < Permutation.size(); ++i )
		Permutation[i]=(int)i;
	
	for ( int j = 0; j < M.cols(); ++j )
	{
		// 2.3.12, numerical recipes: entries of U above the diagonal
		for ( int i = 0; i < j ; ++i )
		{
			double sum = M.element(i,j);
			for ( int k = 0; k < i ; ++k )
				sum -= M.element(i,k) * M.element(k,j);
			M.element(i,j) = sum;
		}

		// 2.3.13, numerical recipes (before division by the pivot), combined
		// with the search for the row with the largest scaled magnitude.
		// The pivot defaults to the diagonal row so that a column without
		// any non-zero candidate causes no row exchange at all.
		int pivot_index = j;
		double max = 0;
		for ( int i = j; i < M.rows(); ++i )
		{
			double sum = M.element(i,j);
			for ( int k = 0; k < j; ++k )
				sum -= M.element(i,k) * M.element(k,j);
			M.element(i,j) = sum;	

			const double scaled = std::fabs(v[i]*sum);
			if ( scaled > max )
			{
				max = scaled; // compare like with like: scaled magnitudes
				pivot_index = i;
			}
		}

		if ( j != pivot_index )
		{
			std::swap( Permutation[j],Permutation[pivot_index] );
			rowswap ( M, pivot_index, j ); // rowswap
			Parity = - Parity; // reverse permutation sign
			// The row that used to be at j now lives at pivot_index and is
			// still a pivot candidate for later columns: carry its scale
			// factor along. v[j] itself is not consulted again.
			v[pivot_index] = v[j];
		}

		// divide by pivot. again, recall 2.3.13
		const double pivot = M.element(j,j);
		if ( pivot != 0 )
		{
			const double temp = 1 / pivot;
			for ( int i = j+1; i < M.rows(); ++i )
				M.element(i,j) *= temp;
		}
	}	
}

/**
  * Swap rows i and j of theMatrix in place, element by element.
  * Does nothing when i == j or when the matrix has no columns (in the
  * latter case there is no element that could safely be dereferenced).
  */
template<class manifold, class iterator, class generator>
void rowswap ( Matrix<manifold,iterator,generator>& theMatrix , int i, int j)
{
	const int ncols = theMatrix.cols();
	if ( i == j || ncols <= 0 )
		return;

	iterator it1 = theMatrix.begin();
	iterator it2 = theMatrix.begin();
	it1 += i * theMatrix.rowSkip();
	it2 += j * theMatrix.rowSkip();

	// Advance only while another column remains, so neither iterator is
	// moved past the last element of its row.
	for ( int k = 0; ; )
	{
		std::iter_swap(it1,it2);
		if ( ++k == ncols )
			break;
		it1 += theMatrix.colSkip();
		it2 += theMatrix.colSkip();
	}
}

/**
  * Fills v[i] with 1 / (largest absolute value in row i of M), the scale
  * factors used for implicit pivoting in compute().
  *
  * \param v output vector; must have at least M.rows() entries
  *
  * Throws RUMBA::Exception if v is too small, or if some row of M consists
  * entirely of zeros (the matrix is then singular).
  */
template<class Matrix>
void LU_Functor<Matrix>::getScaling(std::vector<double>& v)
{
	const int row_max = M.rows();
	const int col_max = M.cols();

	if ( row_max > 0 && v.size() < (size_t)row_max )
		throw RUMBA::Exception("vector is too damn small");

	for ( int i = 0; i < row_max; ++i )
	{
		double max = 0;
		for ( int j = 0; j < col_max; ++j )
		{
			// Take the magnitude first, then compare: writing this as
			// "temp = fabs(x) > max" would store the comparison result
			// (0 or 1) instead of the magnitude.
			const double temp = std::fabs ( M.element(i,j) );
			if ( temp > max )
				max = temp;
		}
		if ( max == 0 )
			throw RUMBA::Exception("Singular Matrix Exception");
		v[i] = 1.0/max;
	}
}

/**
  * Returns the sum of the main-diagonal elements of M.
  *
  * The diagonal has min(rows, cols) entries, so for a non-square matrix
  * only the elements (i,i) that actually exist are summed; no element
  * outside the matrix is read. Returns 0 for an empty matrix.
  *
  * \param M any type providing rows(), cols() and element(i,j)
  */
template<class Matrix>
double trace ( const Matrix & M)
{
	const int n = (M.rows() < M.cols() ) ? M.rows() : M.cols();
	double sum=0;
	for ( int i = 0; i < n; ++i )
		sum += M.element(i,i);
	return sum;
}


} // namespace RUMBA

// ---------------- PROTOTYPES
//
// Every operator is declared here, ahead of all the definitions, so that
// a definition may use any of the others regardless of the order in which
// they appear. This matters for the scalar overloads of operator*: the
// Matrix*Matrix definition uses Matrix*double, and because these
// operators live in the global namespace while Matrix lives in RUMBA,
// argument-dependent lookup at instantiation time does not find them; they
// must be visible at the point of definition.


// Matrix product (or scalar multiplication when either operand is 1x1).
template<class manifold,class iterator, class generator>
RUMBA::Matrix<manifold,iterator,generator>
operator*
( 
 	const RUMBA::Matrix<manifold,iterator,generator>& left, 
	const RUMBA::Matrix<manifold,iterator,generator>& right 
);

// Element-wise sum of two matrices of equal dimensions.
template<class manifold,class iterator, class generator>
RUMBA::Matrix<manifold,iterator,generator>
operator+
( 
 	const RUMBA::Matrix<manifold,iterator,generator>& left, 
	const RUMBA::Matrix<manifold,iterator,generator>& right 
);

// Element-wise difference of two matrices of equal dimensions.
template<class manifold,class iterator, class generator>
RUMBA::Matrix<manifold,iterator,generator>
operator-
( 
 	const RUMBA::Matrix<manifold,iterator,generator>& left, 
	const RUMBA::Matrix<manifold,iterator,generator>& right 
);

// Matrix scaled by a number on the right.
template<class manifold,class iterator, class generator>
RUMBA::Matrix<manifold,iterator,generator>
operator*
( 	
	const RUMBA::Matrix<manifold,iterator,generator>& left, 
	double right 
);

// Matrix scaled by a number on the left.
template<class manifold,class iterator, class generator>
RUMBA::Matrix<manifold,iterator,generator>
operator*
( 	
	double left,
	const RUMBA::Matrix<manifold,iterator,generator>& right 
);

//---------------------------------------------------------


/**
  * Product of two matrices. A 1x1 operand is treated as a scalar and the
  * other operand is scaled by its single element (the left operand is
  * examined first); otherwise the work is delegated to RUMBA::multiply().
  */
template<class manifold,class iterator, class generator>
RUMBA::Matrix<manifold,iterator,generator>
operator*
( 
 	const RUMBA::Matrix<manifold,iterator,generator>& left, 
	const RUMBA::Matrix<manifold,iterator,generator>& right 
)
{
	const bool leftIsScalar = ( left.rows() == 1 && left.cols() == 1 );
	if ( leftIsScalar )
		return right * left.element(0,0);

	const bool rightIsScalar = ( right.rows() == 1 && right.cols() == 1 );
	if ( rightIsScalar )
		return left * right.element(0,0);

	return RUMBA::multiply(left,right);
}

/**
  * Element-wise sum of two matrices.
  *
  * \return a matrix obtained from generator(left) with
  * result(i,j) = left(i,j) + right(i,j)
  *
  * Throws RUMBA::Exception if the operands differ in rows() or cols().
  *
  * Elements are addressed through element(i,j), which applies each
  * matrix's own strides. Unlike walking raw iterators column by column and
  * rewinding afterwards, this never forms an iterator beyond the end of
  * the underlying storage.
  */
template<class manifold,class iterator, class generator>
RUMBA::Matrix<manifold,iterator,generator>
operator+
( 
 	const RUMBA::Matrix<manifold,iterator,generator>& left, 
	const RUMBA::Matrix<manifold,iterator,generator>& right 
)
{
	if ( left.rows() != right.rows() || left.cols() != right.cols() )
		throw RUMBA::Exception("incompatible dimensions");

	generator G;
	RUMBA::Matrix<manifold,iterator,generator> result = G(left);  

	const int nrows = result.rows();
	const int ncols = result.cols();
	for ( int i = 0; i < nrows; ++i )
	{
		for (int j = 0; j < ncols; ++j )
		{
			result.element(i,j) = left.element(i,j) + right.element(i,j);
		}
	}
	
	return result;
}

/**
  * Element-wise difference of two matrices.
  *
  * \return a matrix obtained from generator(left) with
  * result(i,j) = left(i,j) - right(i,j)
  *
  * Throws RUMBA::Exception if the operands differ in rows() or cols().
  *
  * Elements are addressed through element(i,j), which applies each
  * matrix's own strides. Unlike walking raw iterators column by column and
  * rewinding afterwards, this never forms an iterator beyond the end of
  * the underlying storage.
  */
template<class manifold,class iterator, class generator>
RUMBA::Matrix<manifold,iterator,generator>
operator-
( 
 	const RUMBA::Matrix<manifold,iterator,generator>& left, 
	const RUMBA::Matrix<manifold,iterator,generator>& right 
)
{
	if ( left.rows() != right.rows() || left.cols() != right.cols() )
		throw RUMBA::Exception("incompatible dimensions");

	generator G;
	RUMBA::Matrix<manifold,iterator,generator> result = G(left);  

	const int nrows = result.rows();
	const int ncols = result.cols();
	for ( int i = 0; i < nrows; ++i )
	{
		for (int j = 0; j < ncols; ++j )
		{
			result.element(i,j) = left.element(i,j) - right.element(i,j);
		}
	}
	
	return result;
}

/**
  * Scales every element of a matrix by a number.
  *
  * \return a matrix obtained from generator(left) with
  * result(i,j) = right * left(i,j)
  *
  * Elements are addressed through element(i,j) rather than by stepping
  * linearly through both storages, so the result is correct whether or not
  * the generated matrix has the same storage layout as left (for instance
  * when left is a transposed view).
  * Assumes generator(left) yields a matrix with the same rows() and cols()
  * as left, as operator+ and operator- also do.
  */
template<class manifold,class iterator, class generator>
RUMBA::Matrix<manifold,iterator,generator>
operator*
( 	
	const RUMBA::Matrix<manifold,iterator,generator>& left, 
	double right 
)
{
	generator G;
	RUMBA::Matrix<manifold,iterator,generator> result = G(left);  

	const int nrows = left.rows();
	const int ncols = left.cols();
	for ( int i = 0; i < nrows; ++i )
		for ( int j = 0; j < ncols; ++j )
			result.element(i,j) = right * left.element(i,j);
	return result;		
}

/**
  * Scales a matrix by a number written on the left. Defers to the
  * (Matrix, double) overload with the operands exchanged.
  */
template<class manifold,class iterator, class generator>
RUMBA::Matrix<manifold,iterator,generator>
operator*
( 	
	double left,
	const RUMBA::Matrix<manifold,iterator,generator>& right 
)
{
	const double factor = left;
	return right * factor;
}


#endif
