A templated matrix class for mathematical operations in deep learning.
This class provides a comprehensive matrix implementation supporting the common mathematical operations required by deep learning algorithms, including matrix multiplication, element-wise operations, transposition, and various initialization methods.
This implementation uses xtensor as the backend for efficient matrix operations.
auto identity = Matrix<float>::identity(2);                    // 2x2 identity
auto random_matrix = Matrix<float>::random(2, 3, 0.0f, 1.0f);  // 2x3, values between 0 and 1
auto result = identity * random_matrix;                        // (2x2) * (2x3) -> 2x3
#pragma once
#include <initializer_list>
#include <iostream>
#include <tuple>
#include <xtensor/containers/xarray.hpp>
#include <xtensor/core/xmath.hpp>
#include <xtensor/generators/xrandom.hpp>
#include <xtensor/io/xio.hpp>
#include <xtensor/reducers/xreducer.hpp>
#include <xtensor/views/xview.hpp>
/// @brief Two-dimensional matrix whose elements are stored in an xt::xarray<T>.
///
/// The row and column counts are cached in rows_ / cols_ alongside the
/// element storage in data_.
template<typename T>
class Matrix {
public:
    /// @brief Constructs an empty matrix with 0 rows and 0 columns.
    Matrix() : rows_{0}, cols_{0} {}

    /// @brief Constructs a matrix from nested brace-enclosed lists.
    /// NOTE(review): defined out of line; presumably one inner list per row — confirm.
    Matrix(std::initializer_list<std::initializer_list<T>> list);

    // The free functions below are friends so they may reach private members.
    template<typename U>
    friend std::ostream &operator<<(std::ostream &os, const Matrix<U> &matrix);

    template<typename U>
    friend Matrix<U> dot(const Matrix<U> &a, const Matrix<U> &b);

    template<typename U>
    friend U sum(const Matrix<U> &matrix);

    template<typename U>
    friend U mean(const Matrix<U> &matrix);

    /// @brief Read-only access to the element at (@p row, @p col).
    const T &operator()(size_t row, size_t col) const;

    /// @brief Number of rows.
    [[nodiscard]] size_t rows() const { return rows_; }

    /// @brief Number of columns.
    [[nodiscard]] size_t cols() const { return cols_; }

    /// @brief Total element count, i.e. rows() * cols().
    [[nodiscard]] size_t size() const {
        const size_t element_count = rows_ * cols_;
        return element_count;
    }

    /// @brief Both dimensions at once, as a (rows, cols) tuple.
    [[nodiscard]] std::tuple<size_t, size_t> shape() const {
        return std::make_tuple(rows_, cols_);
    }

    /// @brief Mutable access to the underlying xtensor storage.
    /// NOTE(review): callers can change data_'s shape through this reference
    /// without updating rows_/cols_; they must keep the two consistent.
    xt::xarray<T> &data() { return data_; }

    /// @brief Read-only access to the underlying xtensor storage.
    const xt::xarray<T> &data() const { return data_; }

private:
    xt::xarray<T> data_;  ///< Element storage.
    size_t rows_;         ///< Cached row count.
    size_t cols_;         ///< Cached column count.
};
template<typename T>
std::ostream &
operator<<(std::ostream &os,
const Matrix<T> &matrix);
template<typename T>
Matrix<T>
dot(
const Matrix<T> &a,
const Matrix<T> &b);
template<typename T>
T
sum(
const Matrix<T> &matrix);
template<typename T>
T
mean(
const Matrix<T> &matrix);
}
static Matrix zeros(size_t rows, size_t cols)
Create a matrix filled with zeros.
xt::xarray< T > & data()
Get the underlying xtensor array.
size_t cols() const
Get the number of columns.
std::tuple< size_t, size_t > shape() const
Get the shape of the matrix in one step, as a (rows, cols) tuple.
Matrix()
Default constructor creating an empty matrix.
T determinant() const
Calculate the determinant of the matrix.
Matrix transpose() const
Compute the transpose of the matrix.
size_t size() const
Get the total number of elements.
friend std::ostream & operator<<(std::ostream &os, const Matrix< U > &matrix)
friend U sum(const Matrix< U > &matrix)
Matrix inverse() const
Calculate the inverse of the matrix.
T & operator()(size_t row, size_t col)
Access matrix element at specified position (mutable).
friend Matrix< U > dot(const Matrix< U > &a, const Matrix< U > &b)
Matrix operator+(const Matrix &other) const
Matrix addition operator.
Matrix reshape(size_t new_rows, size_t new_cols) const
Reshape the matrix to new dimensions.
static Matrix random(size_t rows, size_t cols, T min, T max)
Create a matrix with random values.
static Matrix identity(size_t size)
Create an identity matrix.
static Matrix ones(size_t rows, size_t cols)
Create a matrix filled with ones.
Matrix operator-(const Matrix &other) const
Matrix subtraction operator.
friend U mean(const Matrix< U > &matrix)
Matrix operator*(const Matrix &other) const
Matrix multiplication operator.
size_t rows() const
Get the number of rows.
T sum(const Matrix< T > &matrix)
Calculate sum of all matrix elements.
T mean(const Matrix< T > &matrix)
Calculate mean of all matrix elements.
Matrix< float > MatrixF
Single-precision floating point matrix.
Matrix< double > MatrixD
Double-precision floating point matrix.
std::ostream & operator<<(std::ostream &os, const Matrix< T > &matrix)
Output stream operator for matrix visualization.
Matrix< T > dot(const Matrix< T > &a, const Matrix< T > &b)
Compute dot product of two matrices.