composer.utils.module_surgery#
Modify model architectures.
Algorithms, such as BlurPool, replace model parameters in-place.

This module contains helper functions to replace parameters in Module and Optimizer instances.
- composer.utils.module_surgery.ReplacementFunction#
Surgery replacement function protocol.
The function is provided with a torch.nn.Module and a counter for the number of instances of the module type that have been seen. The function should return a replacement torch.nn.Module if the module type should be replaced, or None otherwise.
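For illustration, a minimal conforming function might look like the following sketch. The ReLU-to-GELU swap is an arbitrary assumption for the example, not part of this module.

from typing import Optional

from torch import nn

def maybe_replace_relu(module: nn.Module, index: int) -> Optional[nn.Module]:
    # Sketch of a ReplacementFunction: swap every nn.ReLU for nn.GELU.
    # `index` counts how many nn.ReLU instances have been seen so far and
    # could be used to replace only, say, the first occurrence.
    if isinstance(module, nn.ReLU):
        return nn.GELU()
    return None  # None means "leave this module unchanged"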
Functions

count_module_instances – Counts the number of instances of module_class in module, recursively.
replace_module_classes – Modify model in-place by recursively applying replacement policies.
replace_params_in_optimizer – Fully replaces an optimizer's parameters.
update_params_in_optimizer – Remove old_params from the optimizers and insert new_params.

Attributes

ReplacementFunction – Surgery replacement function protocol.
- composer.utils.module_surgery.count_module_instances(module, module_class)[source]#
Counts the number of instances of module_class in module, recursively.

Example

>>> from torch import nn
>>> module = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 64), nn.ReLU())
>>> count_module_instances(module, nn.Linear)
2
>>> count_module_instances(module, (nn.Linear, nn.ReLU))
3
- composer.utils.module_surgery.replace_module_classes(module, policies, optimizers=None, recurse_on_replacements=False, indices=None)[source]#
Modify model in-place by recursively applying replacement policies.
Example
The following example replaces all convolution layers with linear layers, and replaces any linear layer that has 16 input features. Recursion occurs on the replacements.

The first replacement policy replaces the nn.Conv2d(1, 32, 3, 1) layer with an nn.Linear(16, 32) layer. The second replacement policy recurses on this replaced layer: because in_features == 16, it replaces the layer with an nn.Linear(32, 64). The policy is invoked again on this new layer; however, since in_features == 32, no replacement occurs and the policy returns None.

Once every policy either fails to match or returns None on all layers, surgery is finished.

All replacements, including intermediate replacements, are returned.
>>> from torch import nn
>>> module = nn.Sequential(
...     nn.Conv2d(1, 32, 3, 1),
...     nn.ReLU(),
...     nn.MaxPool2d(2),
...     nn.Flatten(),
...     nn.Linear(5408, 128),
...     nn.ReLU(),
...     nn.LogSoftmax(dim=1),
... )
>>> policies = {
...     nn.Conv2d: lambda x, idx: nn.Linear(16, 32),
...     nn.Linear: lambda x, idx: nn.Linear(32, 64) if x.in_features == 16 else None
... }
>>> replace_module_classes(module, policies, recurse_on_replacements=True)
{Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1)): Linear(in_features=16, out_features=32, bias=True), Linear(in_features=16, out_features=32, bias=True): Linear(in_features=32, out_features=64, bias=True)}
Warning
When a module is replaced, any tensor values within the module are not copied over to the new module even when the shape is identical. For example, if model weights are initialized prior to calling this function, the initialized weights will not be preserved in any replacements.
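If initialized weights must survive surgery, one workaround is to copy matching tensors inside the replacement function itself. The following is a sketch under that assumption, not part of this API:

import torch
from torch import nn

def linear_same_shape(module: nn.Linear, index: int) -> nn.Linear:
    # Hypothetical policy: build a same-shape replacement and copy the old
    # tensors over manually, since replace_module_classes will not do this.
    new = nn.Linear(module.in_features, module.out_features, bias=module.bias is not None)
    with torch.no_grad():
        new.weight.copy_(module.weight)
        if module.bias is not None:
            new.bias.copy_(module.bias)
    return new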
- Parameters
module (Module) – Model to modify.

policies (Mapping[Module, ReplacementFunction]) – Mapping of source module class to a replacement function. Matching policies are applied in the iteration order of the dictionary, so if order is important, an OrderedDict should be used. The replacement function may return either another Module or None. If the latter, the source module is not replaced.

recurse_on_replacements (bool) – If true, policies will be applied to any module returned by another policy. For example, if one policy replaces a Conv2d with a module containing another Conv2d, the replacement function will be invoked with this new child Conv2d instance. If the replacement policies are not conditioned on module properties that change during replacement, infinite recursion is possible.

indices (Dict[Any, int], optional) – A dictionary mapping module types to the number of times they've occurred so far in the recursive traversal of module and its child modules. The value is provided to replacement functions, so they may switch behaviors depending on the number of replacements that occurred for a given module type (see the sketch after this parameter list).

Note
These indices may not correspond to the order in which modules get called in the forward pass.

optimizers (Optimizers, optional) – One or more Optimizer objects. If provided, this function will attempt to remove parameters in replaced modules from these optimizers, and add parameters from the newly-created modules. See update_params_in_optimizer() for more information.
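As a sketch of index-dependent behavior (assuming the per-type counter starts at zero), a policy can consult the counter to replace only the first occurrence of a type:

from torch import nn
from composer.utils import module_surgery

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# Replace only the first nn.Linear encountered; later instances see a
# larger index, so the policy returns None and leaves them unchanged.
policy = {nn.Linear: lambda mod, idx: nn.Identity() if idx == 0 else None}
module_surgery.replace_module_classes(model, policies=policy)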
- Returns
Dict[torch.nn.Module, torch.nn.Module] – A dictionary of {original_module: replacement_module} reflecting the replacements applied to module and its children.
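As a usage sketch of the optimizers argument (the model and policy here are illustrative assumptions), passing an optimizer keeps its parameter list in sync with the replaced modules:

import torch
from torch import nn
from composer.utils import module_surgery

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Swap each Conv2d for a wider-kernel Conv2d; the old conv's parameters
# are removed from `opt` and the new conv's parameters are added.
policies = {nn.Conv2d: lambda m, idx: nn.Conv2d(m.in_channels, m.out_channels, 5, padding=2)}
module_surgery.replace_module_classes(model, policies=policies, optimizers=opt)

Without the optimizers argument, opt would still reference the replaced conv's parameters.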
- composer.utils.module_surgery.replace_params_in_optimizer(old_params, new_params, optimizers)[source]#
Fully replaces an optimizer's parameters.

This differs from update_params_in_optimizer() in that this method is capable of replacing parameters spanning multiple param groups. To accomplish this, this function assumes that parameters in new_params should inherit the param group of the corresponding parameter from old_params. Thus, this function also assumes that old_params and new_params have the same length.

- Parameters
old_params (Iterator[Parameter]) – Current parameters of the optimizer.
new_params (Iterator[Parameter]) – New parameters of the optimizer, given in the same order as old_params. Must be the same length as old_params.
optimizers (Optimizers) – One or more torch.optim.Optimizer objects.
- Raises
NotImplementedError – If optimizers contains more than one optimizer.
RuntimeError – If old_params and new_params have different lengths, or if a param from old_params cannot be found.
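A minimal usage sketch, assuming a single optimizer and a one-for-one parameter swap:

import torch
from torch import nn
from composer.utils import module_surgery

layer = nn.Linear(4, 4)
opt = torch.optim.Adam(layer.parameters())

# Replace the layer's parameters one-for-one; each new parameter inherits
# the param group of the old parameter it corresponds to.
new_layer = nn.Linear(4, 4)
module_surgery.replace_params_in_optimizer(
    old_params=layer.parameters(),
    new_params=new_layer.parameters(),
    optimizers=opt,
)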
- composer.utils.module_surgery.update_params_in_optimizer(old_params, new_params, optimizers)[source]#
Remove old_params from the optimizers and insert new_params.

Newly added parameters will be added to the same param_group as the removed parameters. A RuntimeError will be raised if old_params is split across multiple parameter groups.

This function differs from replace_params_in_optimizer() in that len(old_params) need not equal len(new_params). However, this function does not support replacing parameters across multiple optimizer groups.

Warning

Dynamically removing parameters from an Optimizer and adding parameters to an existing param_group are not officially supported, so this function may fail when PyTorch is updated. The recommended practice is to instead recreate the optimizer when the parameter set changes. To simply add new parameters without replacing existing ones, use add_param_group().

- Parameters
old_params – Parameters to remove from the optimizers.
new_params – Parameters to insert into the optimizers in place of old_params.
optimizers (Optimizers) – One or more torch.optim.Optimizer objects.
- Raises
NotImplementedError – If optimizers contains more than one optimizer.
RuntimeError – If not all removed parameters are found in the same parameter group, or if any of them are not found at all.
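A minimal usage sketch, assuming all of the removed parameters live in a single param group:

import torch
from torch import nn
from composer.utils import module_surgery

model = nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Replace the model's parameters with a bias-free layer's parameters;
# len(old_params) need not equal len(new_params).
replacement = nn.Linear(8, 8, bias=False)
module_surgery.update_params_in_optimizer(
    old_params=list(model.parameters()),
    new_params=list(replacement.parameters()),
    optimizers=opt,
)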