/* matrix.hh * vim: set tw=80: * Eryn Wells */ #ifndef __BASICS_MATRIX_HH__ #define __BASICS_MATRIX_HH__ #include #include #include #include "basics/types.hh" namespace charles { namespace basics { /** * A generic, templated Matrix class taking two template parameters. `N` is the * number of rows. `M` is the number of columns. If `M` is not specified, the * matrix will be square. */ template struct Matrix { /** Construct an N x M matrix of zeros. */ static Matrix Zero(); /** * Construct an N x M identity matrix. Identity matrices are always square. * It is a (compile time) error to call Identity on a Matrix class where * N != M. */ static Matrix Identity(); Matrix(); Matrix(const Double data[N*M]); Matrix(const Matrix& rhs); Matrix& operator=(const Matrix& rhs); bool operator==(const Matrix& rhs); bool operator!=(const Matrix& rhs); /** Value accessor. Get the ij'th item. */ Double& operator(uint i, uint j); /** Scalar multiplication */ Matrix operator*(const Double& rhs) const; /** Matrix multiplication */ template Matrix operator*(Matrix rhs) const; /** Get the underlying C array */ const Double* CArray() const; private: /** The matrix data, stored in row-major format. */ Double mData[N * M]; }; typedef Matrix<4> Matrix4; /** Scalar multiplication, scalar factor on the left. */ template Matrix operator*(const Double& lhs, const Matrix& rhs); /* * charles::basics::Matrix<>::Matrix -- */ template Matrix::Matrix() : mData() { } /* * charles::basics::Matrix<>::Matrix -- */ template Matrix::Matrix(const Double data[N*M]) { memcpy(mData, data, sizeof(Double) * N * M); } /* * charles::basics::Matrix<>::Matrix -- */ template Matrix::Matrix(const Matrix& rhs) : Matrix(rhs.mData) { } /* * charles::basics::Matrix<>::operator= -- */ template Matrix& Matrix::operator=(const Matrix& rhs) { memcpy(mData, rhs.mData, sizeof(Double) * N * M); return *this; } /* * charles::basics::Matrix<>::operator== -- */ template bool Matrix::operator==(const Matrix& rhs) const { for (int i = 0; i < N*M; i++) { if (mData[i] != rhs.mData[i]) { return false; } } return true; } /* * charles::basics::Matrix<>::operator!= -- */ template bool Matrix::operator!=(const Matrix& rhs) const { return !(*this == rhs); } /* * charles::basics::Matrix<>::Zero -- */ template Matrix Matrix::Zero() { Matrix m; bzero(m.mData, sizeof(Double) * N * M); return m; } /* * charles::basics::Matrix<>::Identity -- */ template Matrix Matrix::Identity() { static_assert(N == M, "Identity matrices must be square."); auto m = Matrix::Zero(); for (int i = 0; i < N; i++) { for (int j = 0; j < M; j++) { if (i == j) { m(i,j) = 1.0; } } } return m; } /* * charles::basics::Matrix<>::operator() -- */ template Double& Matrix::operator()(uint i, uint j) { assert(i < N && j < M); return mData[i * N + j]; } /* * charles::basics::Matrix<>::operator* -- */ template Matrix Matrix::operator*(const Double& rhs) const { Matrix result; for (int i = 0; i < N*M; i++) { result.mData = mData[i] * rhs; } return result; } /* * charles::basics::Matrix<>::operator* -- */ template template Matrix Matrix::operator*(Matrix rhs) const { Matrix result; for (int i = 0; i < N; i++) { for (int j = 0; j < P; j++) { /* Each cell is Sigma(k=0, M)(lhs[ik] * rhs[kj]) */ const int ij = i*N + j; mCells[ij] = 0.0; for (int k = 0; k < M; k++) { result.mCells[ij] += mCells[i*N + k] * rhs.mCells[k*P + j]; } } } return result; } /* * charles::basics::Matrix<>::CArray -- */ template const Double* Matrix::CArray() const { return mData; } /* * charles::basics::operator* -- */ template Matrix operator*(const Double& lhs, const Matrix& rhs) { return rhs * lhs; } } /* namespace basics */ } /* namespace charles */ #endif /* __BASICS_MATRIX_HH__ */