2#include <xtensor-blas/xlinalg.hpp>
3#include <xtensor/containers/xadapt.hpp>
4#include <xtensor/generators/xrandom.hpp>
5#include <xtensor/views/xview.hpp>
13 Tensor<T>::Tensor(
const std::vector<size_t>& shape) : rows_(shape.size() > 0 ? shape[0] : 0), cols_(shape.size() > 1 ? shape[1] : 1), data_(xt::zeros<T>(shape)) {}
17 Tensor<T>::Tensor(
const std::vector<size_t>& shape, T value) : rows_(shape.size() > 0 ? shape[0] : 0), cols_(shape.size() > 1 ? shape[1] : 1), data_(xt::ones<T>(shape) * value) {}
30 auto shape = array.shape();
39 cols_ = list.begin()->size();
41 std::vector<T> flat_data;
42 flat_data.reserve(rows_ * cols_);
44 for (
const auto &row: list) {
45 std::copy(row.begin(), row.end(), std::back_inserter(flat_data));
48 data_ = xt::adapt(flat_data, {rows_, cols_});
56 return operator()(row, col);
61 return operator()(row, col);
72 if (rows_ != other.rows_ || cols_ != other.cols_) {
73 throw std::invalid_argument(
"Tensor shapes must match for addition");
77 result.data_ = data_ + other.data_;
83 if (rows_ != other.rows_ || cols_ != other.cols_) {
84 throw std::invalid_argument(
"Tensor shapes must match for subtraction");
88 result.data_ = data_ - other.data_;
94 if (rows_ != other.rows_ || cols_ != other.cols_) {
95 throw std::invalid_argument(
"Tensor shapes must match for element-wise multiplication");
99 result.data_ = data_ * other.data_;
106 result.data_ = data_ * scalar;
112 if (data_.dimension() != 2 || other.data_.dimension() != 2) {
113 throw std::invalid_argument(
"Matrix multiplication requires 2D tensors");
115 if (cols_ != other.rows_) {
116 throw std::invalid_argument(
"Tensor dimensions incompatible for matrix multiplication");
120 throw std::runtime_error(
"Matrix multiplication temporarily disabled due to xtensor-blas compatibility issue");
129 if (data_.dimension() != 2) {
130 throw std::invalid_argument(
"Transpose requires 2D tensor");
134 result.data_ = xt::transpose(data_);
140 if (axes.size() != data_.dimension()) {
141 throw std::invalid_argument(
"Number of axes must match tensor dimensions");
144 std::vector<size_t> current_shape = {rows_, cols_};
145 std::vector<size_t> new_shape(axes.size());
146 for (
size_t i = 0; i < axes.size(); ++i) {
147 new_shape[i] = current_shape[axes[i]];
151 result.data_ = xt::transpose(data_, axes);
157 size_t old_size = rows_ * cols_;
158 size_t new_size = std::accumulate(new_shape.begin(), new_shape.end(), 1UL, std::multiplies<size_t>());
160 if (old_size != new_size) {
161 throw std::invalid_argument(
"Total size must remain the same for reshape");
165 result.data_ = xt::reshape_view(data_, new_shape);
171 return reshape({new_rows, new_cols});
176 size_t old_size = rows_ * cols_;
177 size_t new_size = std::accumulate(new_shape.begin(), new_shape.end(), 1UL, std::multiplies<size_t>());
179 if (old_size != new_size) {
180 throw std::invalid_argument(
"Total size must remain the same for view");
184 result.data_ = xt::reshape_view(data_, new_shape);
190 std::vector<size_t> current_shape = {rows_, cols_};
191 std::vector<size_t> new_shape;
195 for (
size_t dim : current_shape) {
197 new_shape.push_back(dim);
202 if (axis >=
static_cast<int>(current_shape.size()) || current_shape[axis] != 1) {
203 throw std::invalid_argument(
"Cannot squeeze dimension that is not 1");
205 for (
size_t i = 0; i < current_shape.size(); ++i) {
206 if (i !=
static_cast<size_t>(axis)) {
207 new_shape.push_back(current_shape[i]);
212 if (new_shape.empty()) {
213 new_shape.push_back(1);
216 return view(new_shape);
221 std::vector<size_t> current_shape = {rows_, cols_};
222 if (axis > current_shape.size()) {
223 throw std::invalid_argument(
"Axis out of range for unsqueeze");
226 std::vector<size_t> new_shape = current_shape;
227 new_shape.insert(new_shape.begin() + axis, 1);
229 return view(new_shape);
234 if (data_.dimension() != 2 || rows_ != cols_) {
235 throw std::invalid_argument(
"Determinant requires square 2D tensor");
238 throw std::runtime_error(
"Determinant computation temporarily disabled due to xtensor-blas compatibility issue");
244 if (data_.dimension() != 2 || rows_ != cols_) {
245 throw std::invalid_argument(
"Inverse requires square 2D tensor");
249 throw std::runtime_error(
"Matrix inverse computation temporarily disabled due to xtensor-blas compatibility issue");
257 if (data_.dimension() != 2 || rows_ != cols_) {
258 throw std::invalid_argument(
"Eigenvalues require square 2D tensor");
261 throw std::runtime_error(
"Eigenvalues computation temporarily disabled due to xtensor-blas compatibility issue");
279 result.data_ = xt::ones<T>(shape);
285 return ones({rows, cols});
295 std::vector<size_t> shape = {size, size};
297 result.data_ = xt::eye<T>(size);
304 if constexpr (std::is_integral_v<T>) {
305 result.data_ = xt::random::randint<T>(shape, T(0), T(10));
307 result.data_ = xt::random::rand<T>(shape, T(0), T(1));
315 if constexpr (std::is_integral_v<T>) {
316 result.data_ = xt::random::randint<T>(shape, min, max);
318 result.data_ = xt::random::rand<T>(shape, min, max);
325 std::vector<size_t> shape = {rows, cols};
327 if constexpr (std::is_integral_v<T>) {
328 result.data_ = xt::random::randint<T>(shape, T(0), T(10));
330 result.data_ = xt::random::rand<T>(shape, T(0), T(1));
337 std::vector<size_t> shape = {rows, cols};
339 if constexpr (std::is_integral_v<T>) {
340 result.data_ = xt::random::randint<T>(shape, min, max);
342 result.data_ = xt::random::rand<T>(shape, min, max);
361 if (a.
data().dimension() != 2 || b.
data().dimension() != 2) {
362 throw std::invalid_argument(
"Dot product requires 2D tensors");
365 throw std::invalid_argument(
"Tensor dimensions incompatible for dot product");
369 throw std::runtime_error(
"Dot product computation temporarily disabled due to xtensor-blas compatibility issue");
377 return xt::sum(tensor.
data())();
382 return xt::mean(tensor.
data())();
Tensor()
Default constructor creating an empty tensor.
size_t cols() const
Get the number of columns.
static Tensor full(const std::vector< size_t > &shape, T value)
Create a tensor filled with a specific value.
Tensor reshape(const std::vector< size_t > &new_shape) const
Reshape the tensor to new dimensions.
static Tensor ones(const std::vector< size_t > &shape)
Create a tensor filled with ones.
Tensor transpose() const
Compute the transpose of the tensor (for 2D tensors)
static Tensor from_array(const xt::xarray< T > &array)
Create a tensor from an existing xt::xarray.
static Tensor identity(size_t size)
Create an identity matrix (2D tensor)
std::tuple< size_t, size_t > shape() const
Get the shape of the matrix in one step.
static Tensor random(const std::vector< size_t > &shape)
Create a random tensor with values between 0 and 1.
Tensor squeeze(int axis=-1) const
Squeeze dimensions of size 1.
auto eigenvalues() const
Calculate eigenvalues of the matrix (for 2D square tensors)
Tensor view(const std::vector< size_t > &new_shape) const
Create a view of the tensor with new shape.
Tensor unsqueeze(size_t axis) const
Add a dimension of size 1.
xt::xarray< T > & data()
Get the underlying xtensor array.
size_t rows() const
Get the number of rows.
Tensor operator-(const Tensor &other) const
Tensor element-wise subtraction operator.
Tensor inverse() const
Calculate the inverse of the matrix (for 2D square tensors)
T determinant() const
Calculate the determinant of the matrix (for 2D square tensors)
Tensor operator*(const Tensor &other) const
Tensor element-wise multiplication operator.
Tensor matmul(const Tensor &other) const
Matrix multiplication (for 2D tensors)
static Tensor zeros(const std::vector< size_t > &shape)
Create a zero tensor.
T & at(size_t row, size_t col)
Access 2D tensor element at specified position (mutable) - backward compatibility.
Tensor operator+(const Tensor &other) const
Tensor element-wise addition operator.
template Tensor< double > dot< double >(const Tensor< double > &, const Tensor< double > &)
template std::ostream & operator<<< float >(std::ostream &, const Tensor< float > &)
template Tensor< float > dot< float >(const Tensor< float > &, const Tensor< float > &)
T mean(const Tensor< T > &tensor)
Calculate mean of all tensor elements.
std::ostream & operator<<(std::ostream &os, const Tensor< T > &tensor)
Output stream operator for tensor visualization.
template double sum< double >(const Tensor< double > &)
template double mean< double >(const Tensor< double > &)
Tensor< T > dot(const Tensor< T > &a, const Tensor< T > &b)
Compute dot product of two tensors.
T sum(const Tensor< T > &tensor)
Calculate sum of all tensor elements.
template float sum< float >(const Tensor< float > &)
template std::ostream & operator<<< double >(std::ostream &, const Tensor< double > &)
template float mean< float >(const Tensor< float > &)