CategoricalCrossEntropy

class usencrypt.ai.losses.CategoricalCrossEntropy(label_smoothing=0.0, axis=0, _config=None)

Computes the cross-entropy between labels and predictions. It is used to set up the cost function in a neural network model.

For \(c\) classes and \(m\) samples, with ground-truth labels \(y\) and predicted labels \(\hat{y}\), this is defined as:

\[\text{CCE}(y, \hat{y}) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{c} y_{ij}\log(\hat{y}_{ij})\]
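As a plaintext illustration of this definition, the double sum can be evaluated directly in NumPy (a minimal sketch of the math only, not the library's encrypted implementation):

>>> import numpy as np
>>> def cce_reference(y_true, y_pred):
...     # -(1/m) * sum_i sum_j y_ij * log(yhat_ij), per the definition above
...     m = y_true.shape[0]
...     return -np.sum(y_true * np.log(y_pred)) / m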

Note

  • This function uses the general definition of cross-entropy, as opposed to the optimized definition of categorical cross-entropy, in order to protect the data through the entire process.

Parameters
  • label_smoothing (float) – The label smoothing parameter, in the range [0, 1]. If label_smoothing > 0, the labels are smoothed by squeezing them towards 0.5 (i.e., 1 - 0.5 * label_smoothing for the target class and 0.5 * label_smoothing for each non-target class); see the sketch after this list. Defaults to 0.0.

  • axis (int) – The axis along which the mean is computed. Defaults to 0.
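The smoothing rule described for label_smoothing can be sketched in plain NumPy (an illustration of the documented transformation, assuming one-hot input labels; this is not the library's internal code):

>>> import numpy as np
>>> def smooth_labels(y_onehot, label_smoothing):
...     # Target entries (1.0) become 1 - 0.5 * label_smoothing;
...     # non-target entries (0.0) become 0.5 * label_smoothing.
...     return y_onehot * (1.0 - label_smoothing) + 0.5 * label_smoothing

For example, with label_smoothing=0.2, a one-hot row [0, 1, 0] maps to [0.1, 0.9, 0.1].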

Call arguments
  • y_true – The ground-truth labels.
  • y_pred – The predicted labels.

Returns

The computed categorical cross-entropy of the input values.

Return type

float or usencrypt.cipher.Float

Warning

Overflow can occur for inputs near x = 60 due to the exponentiation performed in the Newton-Raphson iteration used by usencrypt.log(). See the warnings for usencrypt.log().

Examples

In addition to being set as the loss function of a neural network, an instance of CategoricalCrossEntropy can be called directly, as follows:

>>> import numpy as np
>>> import usencrypt as ue
>>> y_true = np.random.rand(3, 3)
>>> y_true
array([[0.29968769, 0.05661954, 0.56347932],
       [0.13040524, 0.46424133, 0.67697876],
       [0.01753717, 0.71559143, 0.34469593]])
>>> y_pred = np.random.rand(3, 3)
>>> y_pred
array([[0.96913705, 0.09175659, 0.58708196],
       [0.53064407, 0.41190663, 0.92961167],
       [0.38228412, 0.63935748, 0.49057806]])
>>> cce = ue.ai.losses.CategoricalCrossEntropy()
>>> entropy = cce(y_true, y_pred)
>>> entropy
array([0.39195927, 1.03236508, 1.69949544])
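
The documented constructor parameters combine with the same call pattern; for example, with label smoothing enabled (a minimal sketch using only the parameters listed above; output values are omitted since they depend on the random inputs):

>>> cce_smooth = ue.ai.losses.CategoricalCrossEntropy(label_smoothing=0.1)
>>> entropy_smooth = cce_smooth(y_true, y_pred)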