composer.models.base
Classes
ComposerClassifier | Implements the base logic that all classifiers can build on top of.
ComposerModel | The minimal interface needed to use a model with composer.trainer.Trainer.
- class composer.models.base.ComposerClassifier(module)
Bases: composer.models.base.ComposerModel
Implements the base logic that all classifiers can build on top of.
Inherits from ComposerModel.
- Parameters
module (Module) – The neural network module to wrap with ComposerClassifier.
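A rough sketch of the wrapping pattern, written in plain PyTorch so it runs without Composer installed. It assumes a batch is an `(inputs, targets)` tuple and that the loss is ordinary cross-entropy; `MiniClassifier` is a hypothetical name, and the real `ComposerClassifier` may handle batches or compute its loss differently.

```python
import torch
import torch.nn.functional as F
from torch import nn


class MiniClassifier(nn.Module):
    """Hypothetical stand-in showing how a classifier can wrap an arbitrary module."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module  # the wrapped network

    def forward(self, batch):
        # Assumption: batch is an (inputs, targets) tuple; only inputs feed the network.
        inputs, _ = batch
        return self.module(inputs)

    def loss(self, outputs, batch):
        # Assumption: plain cross-entropy between logits and integer class labels.
        _, targets = batch
        return F.cross_entropy(outputs, targets)

    def validate(self, batch):
        # Return (predictions, targets) so metrics can be updated directly.
        _, targets = batch
        return self(batch), targets


torch.manual_seed(0)
model = MiniClassifier(nn.Linear(4, 3))
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))
logits = model(batch)
loss = model.loss(logits, batch)
print(logits.shape, loss.item())
```

Keeping the wrapped network in `self.module` means any `nn.Module` that maps inputs to logits can be reused unchanged; only the batch unpacking and the loss live in the wrapper.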
- class composer.models.base.ComposerModel
Bases: torch.nn.modules.module.Module, abc.ABC
The minimal interface needed to use a model with composer.trainer.Trainer.
- abstract loss(outputs, batch, *args, **kwargs)
Compute the loss of the model.
- Parameters
outputs (Any) – The output of the forward pass.
batch (Batch) – The input batch from the dataloader.
- Returns
Tensors – The loss as a Tensors object.
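Because `ComposerModel` derives from both `torch.nn.Module` and `abc.ABC`, a subclass that omits the abstract `loss()` cannot be instantiated. The sketch below reproduces that contract with a hypothetical stand-in base, `MiniComposerModel`, instead of importing Composer; the mean-squared-error loss in the hypothetical `Regressor` is an arbitrary choice for illustration.

```python
import abc

import torch
import torch.nn.functional as F
from torch import nn


class MiniComposerModel(nn.Module, abc.ABC):
    """Hypothetical stand-in with the same bases as ComposerModel (Module, ABC)."""

    @abc.abstractmethod
    def loss(self, outputs, batch, *args, **kwargs) -> torch.Tensor:
        """Subclasses must compute the loss from forward outputs and the batch."""


class Incomplete(MiniComposerModel):
    pass  # does not implement loss(), so it stays abstract


class Regressor(MiniComposerModel):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 1)

    def forward(self, batch):
        inputs, _ = batch
        return self.net(inputs)

    def loss(self, outputs, batch, *args, **kwargs):
        # Assumption for illustration: mean squared error against the targets.
        _, targets = batch
        return F.mse_loss(outputs, targets)


try:
    Incomplete()
except TypeError as err:
    print("cannot instantiate:", err)

model = Regressor()
batch = (torch.randn(5, 2), torch.randn(5, 1))
print(model.loss(model(batch), batch))
```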
- metrics(train=False)
Get metrics for evaluating the model.
Warning
Each metric keeps state that is updated with the data seen so far. As a result, separate metric instances should be used for training and validation. See https://torchmetrics.readthedocs.io/en/latest/pages/overview.html for more details.
- Parameters
train (bool, optional) – True to return metrics that should be computed during training and False otherwise. (default: False)
- Returns
Metrics – A Metrics object.
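To follow the warning about metric state, `metrics(train)` can hand back different objects for the two phases so that training updates never leak into validation results. This sketch uses a hand-rolled, hypothetical `RunningAccuracy` class in place of a torchmetrics `Metric` to stay dependency-free, and `ToyMetricsOwner` is likewise hypothetical; neither is part of Composer.

```python
import torch


class RunningAccuracy:
    """Hypothetical minimal stateful metric (not torchmetrics): accumulates counts."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        # preds are logits; the predicted class is the argmax over the last dim.
        self.correct += int((preds.argmax(dim=-1) == targets).sum())
        self.total += targets.numel()

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


class ToyMetricsOwner:
    """Sketch of a metrics(train) method returning distinct objects per phase."""

    def __init__(self):
        self.train_acc = RunningAccuracy()
        self.val_acc = RunningAccuracy()

    def metrics(self, train: bool = False) -> RunningAccuracy:
        return self.train_acc if train else self.val_acc


owner = ToyMetricsOwner()
logits = torch.tensor([[2.0, 0.0], [0.0, 2.0]])  # predicted classes: [0, 1]
owner.metrics(train=True).update(logits, torch.tensor([0, 1]))   # 2 of 2 correct
owner.metrics(train=False).update(logits, torch.tensor([1, 1]))  # 1 of 2 correct
print(owner.metrics(train=True).compute(), owner.metrics(train=False).compute())  # 1.0 0.5
```

Had a single instance been shared, the reported accuracy would be 0.75 for both phases, mixing training and validation data.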
- validate(batch)
Compute model outputs on the provided data.
The output of this function is used directly as input to all metrics returned by metrics().
- Parameters
batch (Batch) – The data to perform validation with, specified as a tuple of tensors (input, target).
- Returns
Tuple[Any, Any] – Tuple that is passed directly to the update() methods of the metrics returned by metrics(). Most often, this will be a tuple of the form (predictions, targets).
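The hand-off from validate() to the metrics can be sketched as a small evaluation loop in which the returned tuple is unpacked straight into `update()`. This is an assumption about how a trainer might consume the tuple, not Composer's actual Trainer code; `ToyModel` and `CountCorrect` are hypothetical names.

```python
import torch
from torch import nn


class CountCorrect:
    """Hypothetical metric exposing update(preds, targets), like a torchmetrics Metric."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        self.correct += int((preds.argmax(dim=-1) == targets).sum())
        self.total += targets.numel()


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3, 2)

    def validate(self, batch):
        # Assumption: batch is (inputs, targets); return (predictions, targets).
        inputs, targets = batch
        return self.net(inputs), targets


model = ToyModel()
metric = CountCorrect()
model.eval()
with torch.no_grad():
    for _ in range(3):  # three hypothetical validation batches of 4 samples
        batch = (torch.randn(4, 3), torch.randint(0, 2, (4,)))
        outputs = model.validate(batch)
        metric.update(*outputs)  # the (predictions, targets) tuple is unpacked
print(metric.correct, "/", metric.total)
```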