Deep Learning Algorithm Implementations 1.0.0
C++ implementations of fundamental deep learning algorithms
Loading...
Searching...
No Matches
utils::Variable< T > Class Template Reference

Variable class that supports automatic differentiation. More...

#include <autograd.hpp>

Public Member Functions

 Variable (const Matrix< T > &data, bool requires_grad=false)
 Constructor.
 
 Variable (const Matrix< T > &data, std::shared_ptr< Function< T > > grad_fn)
 Constructor with gradient function.
 
const Matrix< T > & data () const
 
Matrix< T > & data ()
 
const Matrix< T > & grad () const
 
Matrix< T > & grad ()
 
bool requires_grad () const
 
std::shared_ptr< Function< T > > grad_fn () const
 
void backward (const Matrix< T > &gradient=Matrix< T >())
 Perform backward pass.
 
void zero_grad ()
 Zero the gradients.
 
Variable< T > detach () const
 Detach from computational graph.
 
Variable< T > operator+ (const Variable< T > &other) const
 
Variable< T > operator- (const Variable< T > &other) const
 
Variable< T > operator* (const Variable< T > &other) const
 
Variable< T > dot (const Variable< T > &other) const
 
Variable< T > transpose () const
 
Variable< T > sum () const
 
Variable< T > mean () const
 
Variable< T > sigmoid () const
 
Variable< T > tanh () const
 
Variable< T > relu () const
 
Variable< T > exp () const
 
Variable< T > log () const
 
T & operator() (size_t row, size_t col)
 
const T & operator() (size_t row, size_t col) const
 
size_t rows () const
 
size_t cols () const
 

Detailed Description

template<typename T>
class utils::Variable< T >

Variable class that supports automatic differentiation.

Examples
include/ml/svm.hpp

Definition at line 58 of file autograd.hpp.

Constructor & Destructor Documentation

◆ Variable() [1/2]

template<typename T >
utils::Variable< T >::Variable ( const Matrix< T > &  data,
bool  requires_grad = false 
)
inline

Constructor.

Parameters
data: The matrix data
requires_grad: Whether to compute gradients for this variable

Definition at line 65 of file autograd.hpp.

◆ Variable() [2/2]

template<typename T >
utils::Variable< T >::Variable ( const Matrix< T > &  data,
std::shared_ptr< Function< T > >  grad_fn 
)
inline

Constructor with gradient function.

Definition at line 75 of file autograd.hpp.

Member Function Documentation

◆ backward()

template<typename T >
void utils::Variable< T >::backward ( const Matrix< T > &  gradient = Matrix<T>())

Perform backward pass.

Parameters
gradient: Optional gradient to start with

Definition at line 10 of file autograd.cpp.

◆ cols()

template<typename T >
size_t utils::Variable< T >::cols ( ) const
inline

Definition at line 133 of file autograd.hpp.

◆ data() [1/2]

template<typename T >
Matrix< T > & utils::Variable< T >::data ( )
inline

Definition at line 82 of file autograd.hpp.

◆ data() [2/2]

template<typename T >
const Matrix< T > & utils::Variable< T >::data ( ) const
inline

Definition at line 81 of file autograd.hpp.

◆ detach()

template<typename T >
Variable< T > utils::Variable< T >::detach ( ) const
inline

Detach from computational graph.

Definition at line 106 of file autograd.hpp.

◆ dot()

template<typename T >
Variable< T > utils::Variable< T >::dot ( const Variable< T > &  other) const

Definition at line 69 of file autograd.cpp.

◆ exp()

template<typename T >
Variable< T > utils::Variable< T >::exp ( ) const

Definition at line 147 of file autograd.cpp.

◆ grad() [1/2]

template<typename T >
Matrix< T > & utils::Variable< T >::grad ( )
inline

Definition at line 84 of file autograd.hpp.

◆ grad() [2/2]

template<typename T >
const Matrix< T > & utils::Variable< T >::grad ( ) const
inline

Definition at line 83 of file autograd.hpp.

◆ grad_fn()

template<typename T >
std::shared_ptr< Function< T > > utils::Variable< T >::grad_fn ( ) const
inline

Definition at line 86 of file autograd.hpp.

◆ log()

template<typename T >
Variable< T > utils::Variable< T >::log ( ) const

Definition at line 158 of file autograd.cpp.

◆ mean()

template<typename T >
Variable< T > utils::Variable< T >::mean ( ) const

Definition at line 102 of file autograd.cpp.

◆ operator()() [1/2]

template<typename T >
T & utils::Variable< T >::operator() ( size_t  row,
size_t  col 
)
inline

Definition at line 129 of file autograd.hpp.

◆ operator()() [2/2]

template<typename T >
const T & utils::Variable< T >::operator() ( size_t  row,
size_t  col 
) const
inline

Definition at line 130 of file autograd.hpp.

◆ operator*()

template<typename T >
Variable< T > utils::Variable< T >::operator* ( const Variable< T > &  other) const

Definition at line 58 of file autograd.cpp.

◆ operator+()

template<typename T >
Variable< T > utils::Variable< T >::operator+ ( const Variable< T > &  other) const

Definition at line 36 of file autograd.cpp.

◆ operator-()

template<typename T >
Variable< T > utils::Variable< T >::operator- ( const Variable< T > &  other) const

Definition at line 47 of file autograd.cpp.

◆ relu()

template<typename T >
Variable< T > utils::Variable< T >::relu ( ) const

Definition at line 136 of file autograd.cpp.

◆ requires_grad()

template<typename T >
bool utils::Variable< T >::requires_grad ( ) const
inline

Definition at line 85 of file autograd.hpp.

◆ rows()

template<typename T >
size_t utils::Variable< T >::rows ( ) const
inline

Definition at line 132 of file autograd.hpp.

◆ sigmoid()

template<typename T >
Variable< T > utils::Variable< T >::sigmoid ( ) const

Definition at line 113 of file autograd.cpp.

◆ sum()

template<typename T >
Variable< T > utils::Variable< T >::sum ( ) const

Definition at line 91 of file autograd.cpp.

◆ tanh()

template<typename T >
Variable< T > utils::Variable< T >::tanh ( ) const

Definition at line 124 of file autograd.cpp.

◆ transpose()

template<typename T >
Variable< T > utils::Variable< T >::transpose ( ) const

Definition at line 80 of file autograd.cpp.

◆ zero_grad()

template<typename T >
void utils::Variable< T >::zero_grad ( )
inline

Zero the gradients.

Definition at line 97 of file autograd.hpp.


The documentation for this class was generated from the following files: