composer.core.surgery

Surgery is our library for modifying models. It provides helper functions for algorithms that need to modify a provided model, typically by replacing particular modules with optimized equivalents.

class composer.core.surgery.ReplacementFunction(*args, **kwargs)

Represents a scheme for replacing a model’s modules with other modules.

For typing reasons we represent this as a Protocol, but in practice this class only describes the signature of a function. A replacement function returns either a replacement module or None; returning None means the module is left unmodified.

Parameters
  • module (torch.nn.Module) – Source module

  • module_index (int) – The index i such that module is the i-th instance of its class encountered so far. Replacement functions may use it or ignore it.

Returns

torch.nn.Module, optional – replacement module, or None to indicate no modification.
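As an illustration, here is a minimal sketch of a function that satisfies this protocol, using only plain PyTorch. The name relu_to_gelu_after_first and the rule it implements (keep the first ReLU, replace later ones with GELU) are hypothetical and only show how module_index might be used.

```python
from typing import Optional

import torch.nn as nn


def relu_to_gelu_after_first(module: nn.Module, module_index: int) -> Optional[nn.Module]:
    """Hypothetical replacement function: keep the first ReLU, swap later ones for GELU."""
    if module_index == 0:
        # Returning None signals "leave this module unchanged".
        return None
    # Any later instance is replaced by a fresh GELU module.
    return nn.GELU()


# Calling it directly, as a surgery routine would for each matching module.
kept = relu_to_gelu_after_first(nn.ReLU(), 0)     # None: the module is kept
swapped = relu_to_gelu_after_first(nn.ReLU(), 1)  # an nn.GELU instance
```

Because the function receives the module itself, a real policy could also inspect its attributes (for example, in_features on a linear layer) before deciding.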

composer.core.surgery.count_module_instances(model: torch.nn.modules.module.Module, module_class: Type[torch.nn.modules.module.Module]) → int

Counts the number of instances of module_class in the model.

Example

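>>> import torch.nn as nn
>>> from composer.core.surgery import count_module_instances
>>> model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 64), nn.ReLU())
>>> count_module_instances(model, nn.Linear)
2
>>> count_module_instances(model, (nn.Linear, nn.ReLU))
3

Parameters
  • model (torch.nn.Module) – Source model.

  • module_class (Type[torch.nn.Module]) – The module class to count. As the example above shows, a tuple of classes is also accepted, in which case instances of any of them are counted.

Returns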

int – The number of instances of module_class in model
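One way such a count could be computed is a recursive walk over child modules using isinstance, which would also explain why a tuple of classes works. The sketch below is an assumption about the approach, not the library's source; count_instances_sketch is a hypothetical name, and the sketch assumes the root model itself is not counted.

```python
from typing import Tuple, Type, Union

import torch.nn as nn


def count_instances_sketch(
    model: nn.Module,
    module_class: Union[Type[nn.Module], Tuple[Type[nn.Module], ...]],
) -> int:
    """Hypothetical sketch: count descendants of `model` that are instances of `module_class`."""
    count = 0
    for child in model.children():
        # isinstance accepts either a single class or a tuple of classes.
        if isinstance(child, module_class):
            count += 1
        # Descend into the child's own submodules.
        count += count_instances_sketch(child, module_class)
    return count


# Nested containers are searched too; the Sequential containers themselves are not counted.
model = nn.Sequential(nn.Linear(16, 32), nn.Sequential(nn.Linear(32, 64), nn.ReLU()))
n_linear = count_instances_sketch(model, nn.Linear)            # 2
n_both = count_instances_sketch(model, (nn.Linear, nn.ReLU))   # 3
```

Walking children() recursively, rather than iterating over model.modules(), is what keeps the root model out of the count in this sketch.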

composer.core.surgery.replace_module_classes(model: torch.nn.modules.module.Module, policies: Dict[Any, composer.core.surgery.ReplacementFunction], recurse_on_replacements: bool = False, indices: Optional[Dict[Any, int]] = None) → List[Tuple[torch.nn.modules.module.Module, torch.nn.modules.module.Module]]

Modify model in place by recursively applying replacement policies. Replacement policies are a mapping from source classes to ReplacementFunction callables.

Examples

The following policy:

import torch.nn as nn

policies = {
    nn.Conv2d: lambda x, idx: nn.Linear(16, 32),
    nn.MaxPool2d: lambda x, idx: nn.AvgPool2d(3, stride=2),
    nn.Linear: lambda x, idx: nn.Linear(16, 64) if x.in_features == 32 else None
}

will replace all convolution layers with linear layers and all max pooling layers with average pooling layers. A linear layer is replaced only if its in_features equals 32; otherwise the policy returns None and the layer is left unchanged.
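To see what the example policy decides for individual modules, each replacement function can be called directly. The helper apply_policy below is hypothetical, and its exact-class lookup via type(module) is a simplification for illustration; the library may match subclasses as well.

```python
import torch.nn as nn

# The example policy from above, unchanged.
policies = {
    nn.Conv2d: lambda x, idx: nn.Linear(16, 32),
    nn.MaxPool2d: lambda x, idx: nn.AvgPool2d(3, stride=2),
    nn.Linear: lambda x, idx: nn.Linear(16, 64) if x.in_features == 32 else None,
}


def apply_policy(module: nn.Module, index: int = 0):
    """Hypothetical helper: look up the policy for the module's exact class and call it."""
    fn = policies.get(type(module))
    return None if fn is None else fn(module, index)


conv_out = apply_policy(nn.Conv2d(3, 8, kernel_size=3))  # nn.Linear(16, 32)
pool_out = apply_policy(nn.MaxPool2d(2))                  # nn.AvgPool2d(3, stride=2)
wide_out = apply_policy(nn.Linear(32, 10))                # nn.Linear(16, 64)
narrow_out = apply_policy(nn.Linear(8, 10))               # None: in_features != 32, layer kept
relu_out = apply_policy(nn.ReLU())                        # None: no policy for ReLU
```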

Parameters
  • model – Model to modify.

  • policies – Mapping from source class to replacement function. Each replacement function returns either another module or None; if it returns None, the module is left unchanged.

  • recurse_on_replacements – If True, policies will also be applied to any module returned by another policy. For example, if a Conv2d is replaced with a module that contains another Conv2d, that new child Conv2d may also be replaced. This can recurse infinitely unless the replacement policies are conditioned on module properties that change over the course of the recursion.

  • indices – A dictionary mapping module types to the number of times they’ve occurred so far in the recursive traversal of model and its child modules. This is how module_index is passed to the replacement policies, so that a policy may change its behavior on the i-th instance of a module class. Note that these indices may not correspond to the order in which modules are called in the forward pass.

Returns

replaced_pairs – a list of (original module, replacement module) pairs, reflecting the replacements applied to model and its children.
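The traversal described by these parameters can be sketched as follows. This is a hypothetical re-implementation for illustration only (replace_modules_sketch is not part of Composer), and it assumes that children are matched with isinstance, that only the first matching policy is consulted for each child, and that the children of an unreplaced module are always visited.

```python
from typing import Callable, Dict, List, Optional, Tuple, Type

import torch.nn as nn

Policy = Callable[[nn.Module, int], Optional[nn.Module]]


def replace_modules_sketch(
    model: nn.Module,
    policies: Dict[Type[nn.Module], Policy],
    recurse_on_replacements: bool = False,
    indices: Optional[Dict[Type[nn.Module], int]] = None,
) -> List[Tuple[nn.Module, nn.Module]]:
    """Hypothetical sketch: replace matching children of `model` in place."""
    if indices is None:
        # One running counter per policy class, shared across the whole traversal.
        indices = {cls: 0 for cls in policies}
    replaced_pairs: List[Tuple[nn.Module, nn.Module]] = []
    for name, child in list(model.named_children()):
        replacement = None
        for cls, fn in policies.items():
            if isinstance(child, cls):
                replacement = fn(child, indices[cls])
                indices[cls] += 1
                break  # sketch assumption: only the first matching policy is consulted
        if replacement is not None:
            # Swap the child on its parent; this mutates `model` in place.
            setattr(model, name, replacement)
            replaced_pairs.append((child, replacement))
            if recurse_on_replacements:
                # Apply the policies to the replacement's own children as well.
                replaced_pairs += replace_modules_sketch(
                    replacement, policies, recurse_on_replacements, indices
                )
        else:
            # Unreplaced children are traversed so that nested modules are visited.
            replaced_pairs += replace_modules_sketch(
                child, policies, recurse_on_replacements, indices
            )
    return replaced_pairs


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.MaxPool2d(2),
    nn.Sequential(nn.Linear(32, 10), nn.Linear(8, 10)),
)
policies = {
    nn.Conv2d: lambda x, idx: nn.Linear(16, 32),
    nn.MaxPool2d: lambda x, idx: nn.AvgPool2d(3, stride=2),
    nn.Linear: lambda x, idx: nn.Linear(16, 64) if x.in_features == 32 else None,
}
pairs = replace_modules_sketch(model, policies)
```

Here the Conv2d, the MaxPool2d, and the nested Linear(32, 10) are replaced, while Linear(8, 10) is kept because its policy returns None. With recurse_on_replacements left at False, the new Linear(16, 32) is not itself passed back through the policies.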