composer.models.base

Classes

ComposerClassifier

Implements the base logic that all classifiers can build on top of.

ComposerModel

The minimal interface needed to use a model with composer.trainer.Trainer.

class composer.models.base.ComposerClassifier(module)

Bases: composer.models.base.ComposerModel

Implements the base logic that all classifiers can build on top of.

Inherits from ComposerModel.

Parameters

module (Module) – The neural network module to wrap with ComposerClassifier.
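The wrapping described above can be pictured with a minimal, dependency-free sketch. It assumes that a batch is an (inputs, targets) pair, as documented for validate(), and that the wrapper simply hands the inputs to the wrapped module. `WrappedClassifier` and `toy_module` are hypothetical names, not part of Composer, and a plain callable stands in for a torch.nn.Module.

```python
from typing import Any, Callable, List, Sequence, Tuple

Batch = Tuple[Any, Any]  # (inputs, targets), matching validate()'s documented format


class WrappedClassifier:
    """Illustrative wrapper, not Composer's code: delegates computation to `module`."""

    def __init__(self, module: Callable[[Any], Any]) -> None:
        self.module = module  # the wrapped network does the actual computation

    def forward(self, batch: Batch) -> Any:
        inputs, _ = batch  # targets are not needed to produce logits
        return self.module(inputs)


def toy_module(xs: Sequence[float]) -> List[List[float]]:
    # Two-class "logits" [x, -x] for each scalar input.
    return [[x, -x] for x in xs]


clf = WrappedClassifier(toy_module)
print(clf.forward(([0.5, -2.0], [0, 1])))  # → [[0.5, -0.5], [-2.0, 2.0]]
```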

class composer.models.base.ComposerModel

Bases: torch.nn.modules.module.Module, abc.ABC

The minimal interface needed to use a model with composer.trainer.Trainer.
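To make the contract concrete, here is a stdlib-only sketch built on abc.ABC. `MinimalModel`, `IncompleteModel`, and `SquaredErrorModel` are hypothetical; the real ComposerModel also subclasses torch.nn.Module. Only loss() is marked abstract on this page, so only loss() is abstract here; the NotImplementedError defaults for metrics() and validate() are an assumption made for illustration.

```python
import abc
from typing import Any, Tuple


class MinimalModel(abc.ABC):
    """Dependency-free stand-in for the ComposerModel contract (hypothetical)."""

    @abc.abstractmethod
    def loss(self, outputs: Any, batch: Any, *args, **kwargs) -> Any:
        """Subclasses must compute the loss from forward outputs and the batch."""

    def metrics(self, train: bool = False) -> Any:
        # Assumed placeholder; subclasses override to return their metrics.
        raise NotImplementedError

    def validate(self, batch: Any) -> Tuple[Any, Any]:
        # Assumed placeholder; subclasses override to return (predictions, targets).
        raise NotImplementedError


class IncompleteModel(MinimalModel):
    pass  # does not implement loss(), so it cannot be instantiated


class SquaredErrorModel(MinimalModel):
    def loss(self, outputs, batch, *args, **kwargs):
        _, targets = batch  # batch is assumed to be an (inputs, targets) pair
        return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)


try:
    IncompleteModel()
except TypeError as err:
    print("cannot instantiate:", err)

print(SquaredErrorModel().loss([1.0, 2.0], (None, [1.0, 4.0])))  # → 2.0
```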

abstract loss(outputs, batch, *args, **kwargs)

Compute the loss of the model.

Parameters
  • outputs (Any) – The output of the forward pass.

  • batch (Batch) – The input batch from the dataloader.

Returns

Tensors – The loss as a Tensors object.
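As an example of what a subclass might put in loss(), the sketch below computes the mean softmax cross-entropy of a batch in pure Python. It assumes `outputs` holds one row of logits per example and `batch` is an (inputs, targets) pair of integer class labels. `CrossEntropyModel` and `softmax_cross_entropy` are hypothetical; with real tensors you would normally call torch.nn.functional.cross_entropy instead.

```python
import math
from typing import Sequence


def softmax_cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean cross-entropy over a batch, computed with the log-sum-exp trick."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total += log_sum_exp - row[target]  # equals -log(softmax(row)[target])
    return total / len(targets)


class CrossEntropyModel:
    """Hypothetical loss() following the documented (outputs, batch) signature."""

    def loss(self, outputs, batch, *args, **kwargs) -> float:
        _, targets = batch  # assumes batch is an (inputs, targets) pair
        return softmax_cross_entropy(outputs, targets)


model = CrossEntropyModel()
print(round(model.loss([[0.0, 0.0]], (None, [1])), 4))  # → 0.6931
```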

metrics(train=False)

Get metrics for evaluating the model.

Warning

Each metric keeps internal state that is updated with the data seen so far. As a result, training and validation should use different metric instances. See https://torchmetrics.readthedocs.io/en/latest/pages/overview.html for more details.

Parameters

train (bool, optional) – True to return metrics that should be computed during training and False otherwise. (default: False)

Returns

Metrics – A Metrics object.
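The warning and the `train` flag fit together as follows. This toy example, with hypothetical `RunningAccuracy` and `ToyModel` classes, mimics the stateful update()/compute() pattern that torchmetrics metrics follow; the real method returns torchmetrics objects rather than these stand-ins. Keeping one instance per phase gives independent results, while a shared instance blends both streams.

```python
from typing import List


class RunningAccuracy:
    """Toy stateful metric following the update()/compute() pattern of torchmetrics."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: List[int], targets: List[int]) -> None:
        # State accumulates across calls; nothing is returned per batch.
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


class ToyModel:
    """Hypothetical model holding one metric instance per phase."""

    def __init__(self) -> None:
        self.train_acc = RunningAccuracy()
        self.val_acc = RunningAccuracy()

    def metrics(self, train: bool = False) -> RunningAccuracy:
        # Return a distinct instance per phase so their states never mix.
        return self.train_acc if train else self.val_acc


model = ToyModel()
model.metrics(train=True).update([1, 1, 1, 1], [1, 1, 1, 1])  # 4/4 correct
model.metrics().update([0, 0], [1, 1])                         # 0/2 correct
print(model.metrics(train=True).compute(), model.metrics().compute())  # → 1.0 0.0

# A single shared instance would report 4/6 for both phases instead.
shared = RunningAccuracy()
shared.update([1, 1, 1, 1], [1, 1, 1, 1])
shared.update([0, 0], [1, 1])
print(shared.compute())
```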

validate(batch)

Compute model outputs on provided data.

The output of this function is passed directly as input to all metrics returned by metrics().

Parameters

batch (Batch) – The data to validate on, given as a tuple of tensors (input, target).

Returns

Tuple[Any, Any] – Tuple that is passed directly to the update() methods of the metrics returned by metrics(). Most often, this is a tuple of the form (predictions, targets).
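The hand-off from validate() to the metrics can be sketched as a small evaluation loop. It assumes the returned tuple is unpacked into each metric's update(preds, targets); `DoublingModel`, `MeanAbsoluteError`, and `evaluate` are hypothetical names, and the loop is an illustration rather than the trainer's actual code.

```python
from typing import Any, List, Sequence, Tuple

Batch = Tuple[List[float], List[float]]  # (input, target), as documented above


class MeanAbsoluteError:
    """Toy stateful metric; update() receives exactly what validate() returns."""

    def __init__(self) -> None:
        self.abs_error = 0.0
        self.count = 0

    def update(self, preds: Sequence[float], targets: Sequence[float]) -> None:
        self.abs_error += sum(abs(p - t) for p, t in zip(preds, targets))
        self.count += len(targets)

    def compute(self) -> float:
        return self.abs_error / self.count


class DoublingModel:
    """Hypothetical model whose prediction for each input x is 2 * x."""

    def validate(self, batch: Batch) -> Tuple[List[float], List[float]]:
        inputs, targets = batch
        preds = [2.0 * x for x in inputs]
        return preds, targets  # the common (predictions, targets) form


def evaluate(model: DoublingModel, batches: Sequence[Batch], metrics: Sequence[Any]) -> List[float]:
    for batch in batches:
        outputs = model.validate(batch)
        for metric in metrics:
            metric.update(*outputs)  # the tuple is unpacked straight into update()
    return [metric.compute() for metric in metrics]


batches = [([1.0, 2.0], [2.0, 5.0]), ([3.0], [6.0])]
print(evaluate(DoublingModel(), batches, [MeanAbsoluteError()]))  # 1/3, about 0.333
```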