factorize_modules
Module factorize_modules.
Functions
Cast a value to a type.
Approximates a \(K \times K\) convolution by factorizing it into a \(K \times K\) convolution with fewer channels followed by a \(1 \times 1\) convolution.
Approximates a matrix by factorizing it into a product of two smaller matrices.
Whether factorizing a module a given amount could possibly yield a benefit.
Classes
Factorized replacement for torch.nn.Conv2d.
Factorized replacement for torch.nn.Linear.
Bundles tensors used by a factorized linear operator.
- class composer.algorithms.factorize.factorize_modules.FactorizedConv2d(in_channels, out_channels, kernel_size, latent_channels=0.25, **kwargs)
Bases: composer.algorithms.factorize.factorize_modules._FactorizedModule
Factorized replacement for torch.nn.Conv2d.
Splits the conv2d operation into two smaller conv2d operations, which are executed sequentially with no nonlinearity in between. The first conv2d can be thought of as projecting the feature maps into a lower-dimensional space, similar to PCA. The second produces outputs of the same shape as the unfactorized version based on the embeddings within this lower-dimensional space. Note that "dimensionality" here refers to the number of channels, not the spatial extent or tensor rank.
The first conv2d has a kernel size of kernel_size, while the second always has a kernel size of \(1 \times 1\). For large kernel sizes, the lower-dimensional space can be nearly as large as min(in_channels, out_channels) and still yield a reduction in multiply-add operations. For kernel sizes of \(1 \times 1\), the break-even point is a 2x reduction in channel count, similar to FactorizedLinear.
See factorize_conv2d() for more details.
- Parameters
in_channels (int) – number of channels in the input image.
out_channels (int) – number of channels produced by the convolution.
kernel_size (int | tuple) – size of the convolving kernel.
latent_channels (int | float, optional) – number of channels in the latent representation produced by the first small convolution. Can be specified as either an integer > 1 or as a float within [0, 1). In the latter case, the value is interpreted as a fraction of min(in_channels, out_channels) and is converted to the equivalent integer value, with a minimum of 1. Default: 0.25.
**kwargs – other arguments to torch.nn.Conv2d are supported and will be used with the first of the two smaller Conv2d operations. However, groups > 1 and dilation > 1 are not currently supported.
- Raises
ValueError – If latent_channels is not small enough for factorization to reduce the number of multiply-add operations. In this regime, factorization is both slower and less expressive than a non-factorized operation. Setting latent_channels to max_allowed_latent_channels() or a smaller value is sufficient to avoid this.
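The break-even behavior described for this class can be checked with a little arithmetic. The sketch below counts multiply-adds per output position for a plain and a factorized convolution. The function names are hypothetical helpers and are not part of Composer; the sketch assumes dense operations with no groups or dilation, and it is not claimed to reproduce the exact value returned by max_allowed_latent_channels().

```python
def conv_madds(in_channels, out_channels, kernel_size):
    """Multiply-adds per output position for an unfactorized conv2d."""
    kh, kw = kernel_size
    return in_channels * out_channels * kh * kw


def factorized_conv_madds(in_channels, out_channels, kernel_size, latent_channels):
    """Multiply-adds per output position for the two-convolution version."""
    kh, kw = kernel_size
    first = in_channels * latent_channels * kh * kw  # K x K conv into the latent space
    second = latent_channels * out_channels          # 1 x 1 conv back to out_channels
    return first + second


def largest_beneficial_latent_channels(in_channels, out_channels, kernel_size):
    """Largest latent channel count with strictly fewer multiply-adds.

    Hypothetical helper for illustration only.
    """
    kh, kw = kernel_size
    full = in_channels * out_channels * kh * kw
    per_latent_channel = in_channels * kh * kw + out_channels
    # Need latent * per_latent_channel < full, with latent an integer.
    return (full - 1) // per_latent_channel


if __name__ == "__main__":
    print(largest_beneficial_latent_channels(64, 64, (3, 3)))
    print(largest_beneficial_latent_channels(64, 64, (1, 1)))
```

Under these assumptions, a 64-to-64-channel layer with a \(3 \times 3\) kernel still saves multiply-adds with 57 latent channels, whereas the \(1 \times 1\) case needs 31 or fewer, just under half of the channel count.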
- property in_channels
See torch.nn.Conv2d.
- property latent_channels
The number of output channels for the first convolution, which is also the number of input channels for the second convolution.
- static max_allowed_latent_channels(in_channels, out_channels, kernel_size)
Returns the largest latent channel count that reduces the number of multiply-adds.
- property out_channels
See torch.nn.Conv2d.
- class composer.algorithms.factorize.factorize_modules.FactorizedLinear(in_features, out_features, bias=True, latent_features=0.25)
Bases: composer.algorithms.factorize.factorize_modules._FactorizedModule
Factorized replacement for torch.nn.Linear.
Splits the linear operation into two smaller linear operations, which are executed sequentially with no nonlinearity in between. The first linear operation can be thought of as projecting the inputs into a lower-dimensional space, similar to PCA. The second produces outputs of the same shape as the unfactorized version based on the embeddings within this lower-dimensional space.
If the lower-dimensional space is less than half the size of the smaller of the input and output dimensionality, this factorization can reduce the number of multiply-adds necessary to compute the output. However, because larger matrix products tend to utilize the hardware better, it may take a reduction of more than 2x to get a speedup in practice.
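The half-size threshold follows from counting multiply-adds per input sample. Writing \(m\) for in_features, \(n\) for out_features, and \(r\) for the latent size (symbols introduced here for illustration only), and ignoring the bias:

\[
\underbrace{m n}_{\text{unfactorized}} \;>\; \underbrace{r\,(m + n)}_{\text{factorized}}
\quad\Longleftrightarrow\quad
r \;<\; \frac{m n}{m + n}
\]

Since \(\frac{m n}{m + n} \ge \frac{\min(m, n)}{2}\), any \(r < \min(m, n) / 2\) is sufficient to reduce the count, and when \(m = n\) the two bounds coincide.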
See factorize_matrix() for more details.
- Parameters
in_features (int) – size of each input sample.
out_features (int) – size of each output sample.
bias (bool, optional) – If set to False, the layer will not learn an additive bias. Default: True.
latent_features (int | float, optional) – size of the latent space. Can be specified as either an integer > 1 or as a float within [0, 0.5). In the latter case, the value is interpreted as a fraction of min(in_features, out_features) and is converted to the equivalent integer value, with a minimum of 1. Default: 0.25.
- Raises
ValueError – If latent_features is not small enough for factorization to reduce the number of multiply-add operations. In this regime, factorization is both slower and less expressive than a non-factorized operation. Setting latent_features < min(in_features, out_features) / 2 or using max_allowed_latent_features() is sufficient to avoid this.
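The parameter handling and the error condition above can be sketched in a few lines. This is a minimal illustration, not Composer's implementation: both function names are hypothetical, and truncating the fractional product toward zero is an assumption, since the rounding rule is not stated here.

```python
def resolve_latent_features(latent_features, in_features, out_features):
    """Turn a fractional latent size into an integer count (minimum 1)."""
    if isinstance(latent_features, float) and latent_features < 1:
        smaller = min(in_features, out_features)
        # Assumption: truncate toward zero; the actual rounding rule may differ.
        return max(1, int(latent_features * smaller))
    return int(latent_features)


def check_latent_features(latent_features, in_features, out_features):
    """Return the integer latent size, or raise if it cannot reduce multiply-adds."""
    latent = resolve_latent_features(latent_features, in_features, out_features)
    factorized = latent * (in_features + out_features)  # two smaller matrix products
    unfactorized = in_features * out_features           # one full matrix product
    if factorized >= unfactorized:
        raise ValueError(
            f"latent_features={latent} does not reduce multiply-adds "
            f"({factorized} >= {unfactorized})"
        )
    return latent


if __name__ == "__main__":
    print(check_latent_features(0.25, 512, 1024))
```

With in_features=512 and out_features=1024, the default fraction of 0.25 resolves to 128 latent features. For a square 512-by-512 layer, a latent size of 256 is exactly break-even and is therefore rejected by this sketch.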
- property in_features
See torch.nn.Linear.
- property latent_features
The dimensionality of the space into which the input is projected by the first matrix in the factorization.
- static max_allowed_latent_features(in_features, out_features)
Returns the largest latent feature count that reduces the number of multiply-adds.
- property out_features
See torch.nn.Linear.
- composer.algorithms.factorize.factorize_modules.factorizing_could_speedup(module, latent_size)
Whether factorizing a module a given amount could possibly yield a benefit.
This computation is based on the number of multiply-add operations involved in the module's current forward pass versus the number that would be involved if it were factorized into two modules using the specified latent size. The operations are assumed to be dense and of the same data type in all cases.
Note that this function returning True does not guarantee a wall-clock speedup, since splitting one operation into two involves more data movement and more per-op overhead.
- Parameters
module (Module) – A torch.nn.Conv2d, torch.nn.Linear, FactorizedConv2d, or FactorizedLinear.
latent_size (int | float) – number of channels (for convolution) or features (for linear) in the latent representation. Can be specified as either an integer > 1 or as a float within [0, 1). In the latter case, the value is interpreted as a fraction of min(in_features, out_features) for a linear module or min(in_channels, out_channels) for a convolution.
- Returns
bool – Whether the provided amount of factorization could accelerate the provided module. If module is not one of the allowed types, always returns False, since there is no supported way to factorize that module.
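The decision this function describes can be sketched without any framework dependency. In the illustration below, LayerSpec is a hypothetical stand-in for a module and could_reduce_madds is a hypothetical name; neither is part of Composer. The sketch assumes dense operations and truncates fractional latent sizes toward zero, which is an assumption rather than documented behavior.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class LayerSpec:
    """Hypothetical description of a layer, used in place of a real module."""
    kind: str                             # "linear", "conv2d", or anything else
    fan_in: int                           # in_features or in_channels
    fan_out: int                          # out_features or out_channels
    kernel_size: Tuple[int, int] = (1, 1)


def could_reduce_madds(spec: LayerSpec, latent_size: Union[int, float]) -> bool:
    """True if factorizing with this latent size lowers the multiply-add count."""
    if spec.kind not in ("linear", "conv2d"):
        return False  # no supported way to factorize other layer types
    if isinstance(latent_size, float) and latent_size < 1:
        # Fraction of the smaller dimension, with a minimum of 1 (truncation assumed).
        latent = max(1, int(latent_size * min(spec.fan_in, spec.fan_out)))
    else:
        latent = int(latent_size)
    kh, kw = spec.kernel_size if spec.kind == "conv2d" else (1, 1)
    unfactorized = spec.fan_in * spec.fan_out * kh * kw
    factorized = latent * (spec.fan_in * kh * kw + spec.fan_out)
    return factorized < unfactorized


if __name__ == "__main__":
    print(could_reduce_madds(LayerSpec("linear", 512, 512), 0.25))
    print(could_reduce_madds(LayerSpec("batchnorm", 64, 64), 0.25))
```

As with the documented function, a True result here only says the operation count drops; it says nothing about data movement or per-op overhead, so measured wall-clock time can still get worse.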