composer.models.timm.model

Classes

ComposerClassifier

Implements the base logic that all classifiers can build on top of.

Timm

A wrapper around timm.create_model() used to create a ComposerClassifier from a timm model.

class composer.models.timm.model.Timm(model_name, pretrained=False, num_classes=1000, drop_rate=0.0, drop_path_rate=None, drop_block_rate=None, global_pool=None, bn_momentum=None, bn_eps=None)

Bases: composer.models.base.ComposerClassifier

A wrapper around timm.create_model() used to create a ComposerClassifier from a timm model.

Parameters
  • model_name (str) – The timm model name, e.g., ‘resnet50’. A list of available models can be found at https://github.com/rwightman/pytorch-image-models

  • pretrained (bool) – Whether to load ImageNet-pretrained weights. Default: False.

  • num_classes (int) – The number of output classes, needed for classification tasks. Default: 1000.

  • drop_rate (float) – Dropout rate. Default: 0.0.

  • drop_path_rate (float) – Drop path (stochastic depth) rate. If None, the model default is used. Default: None.

  • drop_block_rate (float) – DropBlock rate. If None, the model default is used. Default: None.

  • global_pool (str) – Global pooling type, one of fast, avg, max, avgmax, or avgmaxc. If None, the model default is used. Default: None.

  • bn_momentum (float) – BatchNorm momentum override. If None, the model default is used. Default: None.

  • bn_eps (float) – BatchNorm epsilon override. If None, the model default is used. Default: None.
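Several of the parameters above share one rule: an override left as None keeps the model's own default. The following is a minimal sketch of that rule, assuming the arguments are ultimately forwarded to timm.create_model() as keyword arguments. The helper build_timm_kwargs is hypothetical and is not part of Composer or timm; it only assembles the keyword dictionary and omits any optional override that was not set. It uses the standard library alone, so it runs without timm installed.

```python
from typing import Any, Dict, Optional


def build_timm_kwargs(
    model_name: str,
    pretrained: bool = False,
    num_classes: int = 1000,
    drop_rate: float = 0.0,
    drop_path_rate: Optional[float] = None,
    drop_block_rate: Optional[float] = None,
    global_pool: Optional[str] = None,
    bn_momentum: Optional[float] = None,
    bn_eps: Optional[float] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: collect keyword arguments for timm.create_model().

    Optional overrides left as None are dropped, so the model's own
    defaults stay in effect (the "model default if None" rule).
    """
    # Arguments that always have a concrete value are always forwarded.
    kwargs: Dict[str, Any] = {
        "model_name": model_name,
        "pretrained": pretrained,
        "num_classes": num_classes,
        "drop_rate": drop_rate,
    }
    # Overrides whose None value means "use the model default".
    overrides = {
        "drop_path_rate": drop_path_rate,
        "drop_block_rate": drop_block_rate,
        "global_pool": global_pool,
        "bn_momentum": bn_momentum,
        "bn_eps": bn_eps,
    }
    # Forward only the overrides that were explicitly set.
    kwargs.update({k: v for k, v in overrides.items() if v is not None})
    return kwargs


kwargs = build_timm_kwargs("resnet50", num_classes=10, drop_path_rate=0.1)
print(kwargs)
# {'model_name': 'resnet50', 'pretrained': False, 'num_classes': 10, 'drop_rate': 0.0, 'drop_path_rate': 0.1}
# With timm installed, the result could be forwarded as timm.create_model(**kwargs).
```

With Composer and timm installed, the wrapper itself is constructed from the same arguments, for example Timm(model_name='resnet50', num_classes=10, drop_path_rate=0.1).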