5#include <unordered_set>
11 if (!requires_grad_) {
17 if (gradient.
rows() == 0 || gradient.
cols() == 0) {
27 auto input_grads = grad_fn_->backward(grad);
37 auto add_fn = std::make_shared<AddFunction<T>>();
38 Matrix<T> result = add_fn->forward({*
this, other});
40 if (requires_grad_ || other.requires_grad_) {
48 auto sub_fn = std::make_shared<SubFunction<T>>();
49 Matrix<T> result = sub_fn->forward({*
this, other});
51 if (requires_grad_ || other.requires_grad_) {
59 auto mul_fn = std::make_shared<MulFunction<T>>();
60 Matrix<T> result = mul_fn->forward({*
this, other});
62 if (requires_grad_ || other.requires_grad_) {
70 auto dot_fn = std::make_shared<DotFunction<T>>();
71 Matrix<T> result = dot_fn->forward({*
this, other});
73 if (requires_grad_ || other.requires_grad_) {
81 auto transpose_fn = std::make_shared<TransposeFunction<T>>();
82 Matrix<T> result = transpose_fn->forward({*
this});
92 auto sum_fn = std::make_shared<SumFunction<T>>();
93 Matrix<T> result = sum_fn->forward({*
this});
103 auto sum_result =
sum();
104 T count =
static_cast<T
>(data_.rows() * data_.cols());
114 auto sigmoid_fn = std::make_shared<SigmoidFunction<T>>();
115 Matrix<T> result = sigmoid_fn->forward({*
this});
117 if (requires_grad_) {
126 Matrix<T> result(data_.rows(), data_.cols());
127 for (
size_t i = 0; i < data_.rows(); ++i) {
128 for (
size_t j = 0; j < data_.cols(); ++j) {
129 result(i, j) = std::tanh(data_(i, j));
137 Matrix<T> result(data_.rows(), data_.cols());
138 for (
size_t i = 0; i < data_.rows(); ++i) {
139 for (
size_t j = 0; j < data_.cols(); ++j) {
140 result(i, j) = std::max(
static_cast<T
>(0), data_(i, j));
148 Matrix<T> result(data_.rows(), data_.cols());
149 for (
size_t i = 0; i < data_.rows(); ++i) {
150 for (
size_t j = 0; j < data_.cols(); ++j) {
151 result(i, j) = std::exp(data_(i, j));
159 Matrix<T> result(data_.rows(), data_.cols());
160 for (
size_t i = 0; i < data_.rows(); ++i) {
161 for (
size_t j = 0; j < data_.cols(); ++j) {
162 result(i, j) = std::log(data_(i, j));
PyTorch-like automatic differentiation engine.
Matrix multiplication function.
Function node in the computational graph.
size_t cols() const
Get the number of columns.
static Matrix ones(size_t rows, size_t cols)
Create a matrix filled with ones.
size_t rows() const
Get the number of rows.
Element-wise multiplication function.
Variable class that supports automatic differentiation.
Variable< T > mean() const
Variable< T > operator-(const Variable< T > &other) const
Variable< T > operator+(const Variable< T > &other) const
Variable< T > sigmoid() const
Variable< T > log() const
Variable< T > dot(const Variable< T > &other) const
Variable< T > exp() const
Variable< T > tanh() const
Variable< T > transpose() const
void backward(const Matrix< T > &gradient=Matrix< T >())
Perform backward pass.
Variable< T > sum() const
Variable< T > operator*(const Variable< T > &other) const
Variable< T > relu() const
T sum(const Matrix< T > &matrix)
Calculate sum of all matrix elements.