/* matrix.hh * vim: set tw=80: * Eryn Wells */ #ifndef __BASICS_MATRIX_HH__ #define __BASICS_MATRIX_HH__ #include #include #include #include "basics/types.hh" namespace charles { namespace basics { /** * A generic, templated Matrix class taking two template parameters. `N` is the * number of rows. `M` is the number of columns. */ template struct Matrix { /** Construct an N x M matrix of zeros. */ static Matrix Zero(); /** Construct an N x M identity matrix. */ static Matrix Identity(); /** Value accessor. Get the ij'th item. */ Double& operator(uint i, uint j); /** Scalar multiplication */ Matrix operator*(const Double& lhs) const; /** Matrix multiplication */ template Matrix operator*(Matrix lhs) const; const Double* CArray() const; private: /** The matrix data, stored in row-major format. */ Double mData[N * M]; }; /** Scalar multiplication, scalar factor on the left. */ template Matrix operator*(const Double& lhs, const Matrix& rhs); /* * charles::basics::Matrix<>::Zero -- */ template Matrix Matrix::Zero() { Matrix m; bzero(m.mData, sizeof(Double) * N * M); return m; } /* * charles::basics::Matrix<>::Identity -- */ template Matrix Matrix::Identity() { static_assert(N == M, "Identity matrices must be square."); auto m = Matrix::Zero(); for (int i = 0; i < N; i++) { for (int j = 0; j < M; j++) { if (i == j) { m(i,j) = 1.0; } } } return m; } /* * charles::basics::Matrix<>::operator() -- */ template Double& Matrix::operator()(uint i, uint j) { assert(i < N && j < M); return mData[i * N + j]; } /* * charles::basics::Matrix<>::operator* -- */ template Matrix Matrix::operator*(const Double& lhs) const { Matrix result; for (int i = 0; i < N*M; i++) { result.mData = mData[i] * lhs; } return result; } /* * charles::basics::Matrix<>::operator* -- */ template template Matrix Matrix::operator*(Matrix lhs) const { Matrix result; for (int i = 0; i < N; i++) { for (int j = 0; j < P; j++) { /* Each cell is Sigma(k=0, M)(lhs[ik] * rhs[kj]) */ const int ij = i*N + j; mCells[ij] = 0.0; for (int k = 0; k < M; k++) { result.mCells[ij] += mCells[i*N + k] * rhs.mCells[k*P + j]; } } } return result; } /* * charles::basics::Matrix<>::CArray -- */ template const Double* Matrix::CArray() const { return mData; } /* * charles::basics::operator* -- */ template Matrix operator*(const Double& lhs, const Matrix& rhs) { return rhs * lhs; } } /* namespace basics */ } /* namespace charles */ #endif /* __BASICS_MATRIX_HH__ */