3#include <initializer_list>
7#include <xtensor/containers/xarray.hpp>
8#include <xtensor/core/xmath.hpp>
9#include <xtensor/generators/xrandom.hpp>
10#include <xtensor/io/xio.hpp>
11#include <xtensor/reducers/xreducer.hpp>
12#include <xtensor/views/xview.hpp>
13#include <xtensor/core/xshape.hpp>
58 Tensor() : data_(xt::xarray<T>::from_shape({0})), rows_(0), cols_(0) {}
92 Tensor(std::initializer_list<std::initializer_list<T>> list);
122 template<
typename... Args>
131 template<
typename... Args>
141 T &
at(
size_t row,
size_t col);
150 const T &
at(
size_t row,
size_t col)
const;
242 Tensor view(
const std::vector<size_t>& new_shape)
const;
290 [[nodiscard]]
size_t rows()
const {
return rows_; }
296 [[nodiscard]]
size_t cols()
const {
return cols_; }
302 [[nodiscard]]
size_t size()
const {
return rows_ * cols_; }
308 [[nodiscard]] std::tuple<size_t, size_t>
shape()
const {
return {rows_, cols_}; }
409 xt::xarray<T> &
data() {
return data_; }
415 const xt::xarray<T> &
data()
const {
return data_; }
431 template<
typename... Args>
433 return data_(indices...);
437 template<
typename... Args>
439 return data_(indices...);
friend std::ostream & operator<<(std::ostream &os, const Tensor< U > &tensor)
Tensor()
Default constructor creating an empty tensor.
size_t cols() const
Get the number of columns.
static Tensor full(const std::vector< size_t > &shape, T value)
Create a tensor filled with a specific value.
Tensor reshape(const std::vector< size_t > &new_shape) const
Reshape the tensor to new dimensions.
static Tensor ones(const std::vector< size_t > &shape)
Create a tensor filled with ones.
Tensor transpose() const
Compute the transpose of the tensor (for 2D tensors)
static Tensor from_array(const xt::xarray< T > &array)
Create a tensor from an existing xt::xarray.
static Tensor identity(size_t size)
Create an identity matrix (2D tensor)
std::tuple< size_t, size_t > shape() const
Get the shape of the matrix as a (rows, columns) tuple in one step.
const T & operator()(Args... indices) const
Access tensor element at specified position (const)
size_t size() const
Get the total number of elements.
friend Tensor< U > dot(const Tensor< U > &a, const Tensor< U > &b)
static Tensor random(const std::vector< size_t > &shape)
Create a random tensor with values between 0 and 1.
const xt::xarray< T > & data() const
Get the underlying xtensor array (const)
Tensor squeeze(int axis=-1) const
Squeeze dimensions of size 1.
auto eigenvalues() const
Calculate eigenvalues of the matrix (for 2D square tensors)
friend U sum(const Tensor< U > &tensor)
Tensor view(const std::vector< size_t > &new_shape) const
Create a view of the tensor with new shape.
Tensor unsqueeze(size_t axis) const
Add a dimension of size 1.
xt::xarray< T > & data()
Get the underlying xtensor array.
size_t rows() const
Get the number of rows.
Tensor operator-(const Tensor &other) const
Tensor element-wise subtraction operator.
Tensor inverse() const
Calculate the inverse of the matrix (for 2D square tensors)
T determinant() const
Calculate the determinant of the matrix (for 2D square tensors)
Tensor operator*(const Tensor &other) const
Tensor element-wise multiplication operator.
Tensor matmul(const Tensor &other) const
Matrix multiplication (for 2D tensors).
static Tensor zeros(const std::vector< size_t > &shape)
Create a zero tensor.
T & at(size_t row, size_t col)
Access the 2D tensor element at the specified row and column (mutable); kept for backward compatibility.
friend U mean(const Tensor< U > &tensor)
T & operator()(Args... indices)
Access tensor element at specified position (mutable)
Tensor operator+(const Tensor &other) const
Tensor element-wise addition operator.
T mean(const Tensor< T > &tensor)
Calculate mean of all tensor elements.
std::ostream & operator<<(std::ostream &os, const Tensor< T > &tensor)
Output stream operator for tensor visualization.
Tensor< T > dot(const Tensor< T > &a, const Tensor< T > &b)
Compute dot product of two tensors.
T sum(const Tensor< T > &tensor)
Calculate sum of all tensor elements.