HFCrossEntropy

class composer.metrics.HFCrossEntropy(dist_sync_on_step=False)

Hugging Face-compatible cross-entropy loss.

Adds metric state variables:

  • sum_loss (float): The sum of the per-example loss in the batch.

  • total_batches (float): The number of batches to average across.
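The two state variables combine simply: each update adds one batch's loss to sum_loss and increments total_batches, and the metric value is their ratio. The sketch below models only that bookkeeping in plain Python, assuming each update contributes one scalar loss per batch; it has no torch or distributed support, and the class name `RunningBatchMean` is hypothetical, not part of Composer.

```python
class RunningBatchMean:
    """Hypothetical stand-in for the sum_loss / total_batches state (not Composer's class)."""

    def __init__(self) -> None:
        self.sum_loss = 0.0       # running sum of per-batch losses
        self.total_batches = 0.0  # number of batches seen so far

    def update(self, batch_loss: float) -> None:
        # Each call counts as one batch, regardless of how many examples it held.
        self.sum_loss += batch_loss
        self.total_batches += 1

    def compute(self) -> float:
        # Average over batches, not over individual examples.
        return self.sum_loss / self.total_batches


metric = RunningBatchMean()
for loss in [2.0, 1.0, 3.0]:
    metric.update(loss)
print(metric.compute())  # 2.0
```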

Parameters

dist_sync_on_step (bool, optional) – Synchronize metric state across processes at each forward() before returning the value at the step. Default: False

compute()

Aggregate the state over all processes to compute the metric.

Returns

loss – The loss averaged across all batches as a Tensor.
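Because both state variables are sums, one natural way to aggregate them across processes is to sum each one over all processes and divide only at the end. That weights every batch equally, whereas averaging each process's own average gives a different answer when processes saw different numbers of batches. The sketch below simulates this with plain Python tuples; it assumes a sum reduction for both states, which the text above does not spell out, and the helper name `aggregate_mean_loss` is hypothetical.

```python
from typing import List, Tuple


def aggregate_mean_loss(per_process_states: List[Tuple[float, float]]) -> float:
    """Hypothetical helper: sum-reduce (sum_loss, total_batches) pairs, then divide."""
    global_sum = sum(s for s, _ in per_process_states)
    global_batches = sum(n for _, n in per_process_states)
    return global_sum / global_batches


# Process 0 saw 3 batches (total loss 3.0); process 1 saw 1 batch (loss 5.0).
states = [(3.0, 3.0), (5.0, 1.0)]
print(aggregate_mean_loss(states))  # 2.0, i.e. 8.0 / 4 batches

# Averaging the per-process averages instead weights processes, not batches.
mean_of_means = sum(s / n for s, n in states) / len(states)
print(mean_of_means)  # 3.0, i.e. (1.0 + 5.0) / 2
```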

update(output, target)

Updates the internal state with results from a new batch.

Parameters
  • output (Mapping or Tensor) – The output from the model: either a Tensor or a Mapping that contains the loss or the model logits.

  • target (Tensor) – A Tensor of ground-truth values to compare against.
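The output parameter accepts two shapes of input, so an update has to decide whether to reuse a loss the model already computed or to derive one from logits. The pure-Python sketch below shows one way that dispatch could look, with lists standing in for Tensors. The dictionary keys "loss" and "logits" and both function names are assumptions for illustration, not Composer's documented API.

```python
import math
from collections.abc import Mapping
from typing import List, Sequence, Union

Logits = List[List[float]]  # one row of class scores per example


def mean_cross_entropy(logits: Logits, target: Sequence[int]) -> float:
    """Mean negative log-likelihood of each target class under a per-row softmax."""
    total = 0.0
    for row, label in zip(logits, target):
        m = max(row)  # subtract the row max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
    return total / len(target)


def batch_loss(output: Union[Mapping, Logits], target: Sequence[int]) -> float:
    """Hypothetical dispatch mirroring the documented `output` contract."""
    if isinstance(output, Mapping):
        if "loss" in output:       # assumed key: a precomputed loss is used as-is
            return float(output["loss"])
        logits = output["logits"]  # assumed key: otherwise recompute from logits
    else:
        logits = output            # a bare value is treated as the logits
    return mean_cross_entropy(logits, target)


uniform = [[0.0, 0.0], [0.0, 0.0]]
print(batch_loss(uniform, [0, 1]))              # log(2), about 0.6931
print(batch_loss({"loss": 1.5}, [0, 1]))        # 1.5
print(batch_loss({"logits": uniform}, [0, 1]))  # same value as the first call
```

Reusing a precomputed loss avoids a second forward computation, while recomputing from logits keeps the metric usable when the model returns only raw scores.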