1#include "../include/utils/matrix.hpp"
2#include <xtensor-blas/xlinalg.hpp>
3#include <xtensor/containers/xadapt.hpp>
11 rows_(rows), cols_(cols), data_(xt::ones<T>({
rows,
cols}) * value) {}
16 cols_ = list.begin()->size();
18 std::vector<T> flat_data;
19 flat_data.reserve(rows_ * cols_);
21 for (
const auto &row: list) {
22 std::copy(row.begin(), row.end(), std::back_inserter(flat_data));
25 data_ = xt::adapt(flat_data, {rows_, cols_});
30 if (row >= rows_ || col >= cols_) {
31 throw std::out_of_range(
"Matrix index out of bounds");
33 return data_(row, col);
38 if (row >= rows_ || col >= cols_) {
39 throw std::out_of_range(
"Matrix index out of bounds");
41 return data_(row, col);
46 if (rows_ != other.rows_ || cols_ != other.cols_) {
47 throw std::invalid_argument(
"Matrix dimensions must match for addition");
51 result.data_ = data_ + other.data_;
57 if (rows_ != other.rows_ || cols_ != other.cols_) {
58 throw std::invalid_argument(
"Matrix dimensions must match for subtraction");
62 result.data_ = data_ - other.data_;
68 if (cols_ != other.rows_) {
69 throw std::invalid_argument(
"Invalid dimensions for matrix multiplication");
73 result.data_ = xt::linalg::dot(data_, other.data_);
80 result.data_ = xt::transpose(data_);
86 if (new_rows * new_cols != rows_ * cols_) {
87 throw std::invalid_argument(
"New dimensions must have same total size");
91 result.data_ = xt::reshape_view(data_, {new_rows, new_cols});
98 throw std::invalid_argument(
"Determinant only defined for square matrices");
101 return xt::linalg::det(data_);
106 if (rows_ != cols_) {
107 throw std::invalid_argument(
"Inverse only defined for square matrices");
111 result.data_ = xt::linalg::inv(data_);
117 return xt::linalg::eigh(data_);
123 result.data_ = xt::zeros<T>({rows, cols});
130 result.data_ = xt::ones<T>({rows, cols});
137 result.data_ = xt::eye<T>(size);
144 result.data_ = xt::random::rand<T>({rows, cols}, min, max);
157 throw std::invalid_argument(
"Invalid dimensions for dot product");
161 result.data() = xt::linalg::dot(a.
data(), b.
data());
167 return xt::sum(matrix.
data())(0);
172 return xt::mean(matrix.
data())(0);
static Matrix zeros(size_t rows, size_t cols)
Create a matrix filled with zeros.
xt::xarray< T > & data()
Get the underlying xtensor array.
size_t cols() const
Get the number of columns.
Matrix()
Default constructor creating an empty matrix.
T determinant() const
Calculate the determinant of the matrix.
Matrix transpose() const
Compute the transpose of the matrix.
Matrix inverse() const
Calculate the inverse of the matrix.
T & operator()(size_t row, size_t col)
Access matrix element at specified position (mutable).
Matrix operator+(const Matrix &other) const
Matrix addition operator.
Matrix reshape(size_t new_rows, size_t new_cols) const
Reshape the matrix to new dimensions.
static Matrix random(size_t rows, size_t cols, T min, T max)
Create a matrix with random values.
static Matrix identity(size_t size)
Create an identity matrix.
static Matrix ones(size_t rows, size_t cols)
Create a matrix filled with ones.
Matrix operator-(const Matrix &other) const
Matrix subtraction operator.
Matrix operator*(const Matrix &other) const
Matrix multiplication operator.
size_t rows() const
Get the number of rows.
T sum(const Matrix< T > &matrix)
Calculate sum of all matrix elements.
T mean(const Matrix< T > &matrix)
Calculate mean of all matrix elements.
template double sum< double >(const Matrix< double > &)
template std::ostream & operator<<< double >(std::ostream &, const Matrix< double > &)
template Matrix< double > dot< double >(const Matrix< double > &, const Matrix< double > &)
std::ostream & operator<<(std::ostream &os, const Matrix< T > &matrix)
Output stream operator for matrix visualization.
template Matrix< float > dot< float >(const Matrix< float > &, const Matrix< float > &)
template double mean< double >(const Matrix< double > &)
Matrix< T > dot(const Matrix< T > &a, const Matrix< T > &b)
Compute dot product of two matrices.
template std::ostream & operator<<< float >(std::ostream &, const Matrix< float > &)
template float mean< float >(const Matrix< float > &)
template float sum< float >(const Matrix< float > &)