22 if (reduction_ ==
"mean") {
23 return squared_diff.
mean();
24 }
else if (reduction_ ==
"sum") {
25 return squared_diff.
sum();
51 Variable<T> CrossEntropyLoss<T>::log_softmax(
const Variable<T>& logits) {
72 if (reduction_ ==
"mean") {
74 }
else if (reduction_ ==
"sum") {
102 Variable<T> log_one_minus_pred = (one_var - predictions).log();
104 Variable<T> loss = targets * log_pred + (one_var - targets) * log_one_minus_pred;
107 if (reduction_ ==
"mean") {
109 }
else if (reduction_ ==
"sum") {
129 return bce_loss.
forward(sigmoid_pred, targets);
146 Variable<T> margin = one_var - targets * predictions;
152 if (reduction_ ==
"mean") {
154 }
else if (reduction_ ==
"sum") {
186 if (reduction_ ==
"mean") {
188 }
else if (reduction_ ==
"sum") {
Binary Cross Entropy Loss with autograd support.
Variable< T > forward(const Variable< T > &predictions, const Variable< T > &targets) override
Forward pass: compute binary cross entropy loss.
Binary Cross Entropy with Logits Loss.
Variable< T > forward(const Variable< T > &predictions, const Variable< T > &targets) override
Forward pass: compute BCE loss from logits.
Cross Entropy Loss with autograd support.
Variable< T > forward(const Variable< T > &predictions, const Variable< T > &targets) override
Forward pass: compute cross entropy loss.
Hinge Loss with autograd support.
Variable< T > forward(const Variable< T > &predictions, const Variable< T > &targets) override
Forward pass: compute hinge loss.
Huber Loss with autograd support.
Variable< T > forward(const Variable< T > &predictions, const Variable< T > &targets) override
Forward pass: compute Huber loss.
Mean Squared Error Loss with autograd support.
Variable< T > forward(const Variable< T > &predictions, const Variable< T > &targets) override
Forward pass: compute MSE loss.
Variable class that supports automatic differentiation.
Variable< T > mean() const
Variable< T > sigmoid() const
Variable< T > log() const
Variable< T > exp() const
Variable< T > sum() const
PyTorch-like loss functions with automatic differentiation.