fx_utils
FX-based model transformation and optimization.
Provides utilities for transforming models traced with torch.fx.
Functions
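| Function | Description |
| --- | --- |
| apply_stochastic_residual | Detects residual patterns and replaces them with their stochastic equivalents. |
| count_op_instances | Counts the number of instances of ops in gm. |
| fuse_parallel_linears | Fuses parallel linear layers in the model into a single linear layer. |
| replace_op | Replaces each instance of one or more operators, torch methods, or functions with a single replacement op. |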
- composer.utils.fx_utils.apply_stochastic_residual(gm, drop_rate=0.2)
Detects residual patterns and replaces them with their stochastic equivalents.
- Parameters
gm (GraphModule) – The source FX-traced graph. This can be the symbolic trace of the whole model.
- Returns
GraphModule – Modified GraphModule that has stochastic residual connections.
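A residual block computes y = x + f(x). This page does not spell out what the "stochastic equivalent" computes; a common choice is stochastic depth, where the branch f is skipped at random during training and scaled by its keep probability at inference. The framework-free sketch below assumes that scheme and assumes drop_rate is the probability of skipping the branch; stochastic_residual and double are hypothetical helpers, not Composer's implementation.

```python
import random
from typing import Callable, Optional


def stochastic_residual(
    x: float,
    branch: Callable[[float], float],
    drop_rate: float = 0.2,
    training: bool = True,
    rng: Optional[random.Random] = None,
) -> float:
    """Hypothetical helper: stochastic-depth form of y = x + branch(x).

    Assumes drop_rate is the probability of skipping the residual branch.
    """
    if not 0.0 <= drop_rate < 1.0:
        raise ValueError("drop_rate must be in [0, 1)")
    keep_prob = 1.0 - drop_rate
    if training:
        rng = rng if rng is not None else random.Random()
        # Keep the branch with probability keep_prob; otherwise only the shortcut x survives.
        if rng.random() < keep_prob:
            return x + branch(x)
        return x
    # At inference, scale the branch by keep_prob so the output equals the
    # expected value of the training-time output.
    return x + keep_prob * branch(x)


def double(v: float) -> float:
    """Toy residual branch."""
    return 2.0 * v


rng = random.Random(0)
train_outputs = [stochastic_residual(1.0, double, 0.2, True, rng) for _ in range(5)]
print(train_outputs)  # each entry is 3.0 (branch kept) or 1.0 (branch dropped)
print(stochastic_residual(1.0, double, 0.2, training=False))  # 1.0 + 0.8 * 2.0
```

Passing a seeded random.Random makes the training-time draws reproducible, which is convenient when checking a rewritten graph against the original.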
- composer.utils.fx_utils.count_op_instances(gm, ops)
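Counts the number of instances of ops in gm. ops can be a single op or a list of ops; as the examples on this page show, an op may be a function such as torch.add or operator.add, a method name such as "add", or a module class such as nn.Linear.
Example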
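>>> import operator
>>> import torch
>>> from torch.fx import symbolic_trace
>>> from composer.utils.fx_utils import count_op_instances
>>> class M(torch.nn.Module):
...     def forward(self, x, y):
...         return x + y, torch.add(x, y), x.add(y)
>>> module = M()
>>> traced = symbolic_trace(module)
>>> count_op_instances(traced, torch.add)
1
>>> count_op_instances(traced, [operator.add, torch.add, "add"])
3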
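Under torch.fx tracing, the three additions in the example become different graph nodes: x + y is recorded as a call_function node whose target is operator.add, torch.add(x, y) as a call_function node targeting torch.add, and x.add(y) as a call_method node whose target is the string "add". That is why torch.add alone matches once while the three-element list matches all three. The framework-free sketch below models only that matching rule; ToyNode, toy_count_op_instances, and fake_torch_add are hypothetical names, and the sketch omits matching module classes such as nn.Linear, which count_op_instances also accepts (see the fuse_parallel_linears example).

```python
import operator
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node, keeping only `op` and `target`."""
    op: str      # e.g. "placeholder", "call_function", "call_method", "output"
    target: Any  # a callable for call_function, a method name (str) for call_method


def toy_count_op_instances(nodes: List[ToyNode], ops: Any) -> int:
    """Count call nodes whose target equals `ops` or any entry of a list of ops."""
    op_list = list(ops) if isinstance(ops, (list, tuple)) else [ops]
    return sum(
        1
        for node in nodes
        if node.op in ("call_function", "call_method")
        and any(node.target == op for op in op_list)
    )


def fake_torch_add(a, b):
    """Hypothetical placeholder for torch.add, so the sketch needs no torch install."""
    return a + b


# Roughly the nodes symbolic_trace records for `return x + y, torch.add(x, y), x.add(y)`.
nodes = [
    ToyNode("placeholder", "x"),
    ToyNode("placeholder", "y"),
    ToyNode("call_function", operator.add),    # x + y
    ToyNode("call_function", fake_torch_add),  # torch.add(x, y)
    ToyNode("call_method", "add"),             # x.add(y)
    ToyNode("output", "output"),
]

print(toy_count_op_instances(nodes, fake_torch_add))                          # 1
print(toy_count_op_instances(nodes, [operator.add, fake_torch_add, "add"]))  # 3
```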
- composer.utils.fx_utils.fuse_parallel_linears(gm, keep_weights=False)
Fuses parallel linear layers (nn.Linear modules that take the same input) into a single linear layer.
Example
>>> import torch.nn as nn
>>> from torch.fx import symbolic_trace
>>> from composer.utils.fx_utils import count_op_instances, fuse_parallel_linears
>>> class M(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = nn.Linear(64, 64)
...         self.fc2 = nn.Linear(64, 64)
...     def forward(self, x):
...         y = self.fc1(x)
...         z = self.fc2(x)
...         return y + z
>>> module = M()
>>> traced = symbolic_trace(module)
>>> count_op_instances(traced, nn.Linear)
2
>>> gm = fuse_parallel_linears(traced)
>>> count_op_instances(gm, nn.Linear)
1
- Parameters
gm (GraphModule) – The source FX-traced graph.
- Returns
GraphModule – Modified GraphModule with parallel linears fused.
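Parallel linears that read the same input can be fused by stacking their weights along the output dimension: if y = W1 x + b1 and z = W2 x + b2, then [y; z] = [W1; W2] x + [b1; b2], so one larger matrix multiply replaces two and the result is split back into y and z. The pure-Python sketch below checks that identity on small lists; linear, fuse_linears, and split_outputs are hypothetical helpers, and whether Composer's fuse_parallel_linears builds its fused layer exactly this way (or what keep_weights changes) is an assumption, not something this page states.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def linear(weight: Matrix, bias: Vector, x: Vector) -> Vector:
    """y = W x + b, with W stored as (out_features, in_features) like nn.Linear."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(weight, bias)]


def fuse_linears(layers: List[Tuple[Matrix, Vector]]) -> Tuple[Matrix, Vector, List[int]]:
    """Stack the rows (output units) and biases of linears that share one input.

    Also returns each layer's output size so the fused output can be split.
    """
    fused_w: Matrix = []
    fused_b: Vector = []
    sizes: List[int] = []
    for weight, bias in layers:
        fused_w.extend(weight)
        fused_b.extend(bias)
        sizes.append(len(weight))
    return fused_w, fused_b, sizes


def split_outputs(out: Vector, sizes: List[int]) -> List[Vector]:
    """Cut the fused output back into one vector per original layer."""
    parts: List[Vector] = []
    start = 0
    for size in sizes:
        parts.append(out[start:start + size])
        start += size
    return parts


# Two 2->2 linears applied to the same input, like fc1 and fc2 in the example.
fc1 = ([[1.0, 2.0], [3.0, 4.0]], [0.5, -0.5])
fc2 = ([[0.0, 1.0], [1.0, 0.0]], [1.0, 1.0])
x = [1.0, 1.0]

w, b, sizes = fuse_linears([fc1, fc2])
y, z = split_outputs(linear(w, b, x), sizes)  # one fused matmul instead of two
assert y == linear(*fc1, x) and z == linear(*fc2, x)
print([a + c for a, c in zip(y, z)])  # the model's y + z: [5.5, 8.5]
```

Stacking along the output dimension keeps each layer's rows intact, so the fused result is exactly the concatenation of the separate outputs rather than an approximation.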
- composer.utils.fx_utils.replace_op(gm, src_ops, tgt_op)
Replaces each instance of one or more operators, torch methods, or functions with a single replacement op.
Example
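>>> import operator
>>> import torch
>>> from torch.fx import symbolic_trace
>>> from composer.utils.fx_utils import count_op_instances, replace_op
>>> class M(torch.nn.Module):
...     def forward(self, x, y):
...         return x + y, torch.add(x, y), x.add(y)
>>> module = M()
>>> traced = symbolic_trace(module)
>>> traced = replace_op(traced, [operator.add, torch.add, "add"], torch.mul)
>>> count_op_instances(traced, torch.mul)
3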
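- Parameters
gm (GraphModule) – The source FX-traced graph.
src_ops – The operator, torch method, or function to replace, or a list of them.
tgt_op – The operator, torch method, or function that replaces each instance of src_ops.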
- Returns
GraphModule – Modified GraphModule with each instance of an op in src_ops replaced with tgt_op. Returns the input if no instances are found.
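Conceptually, replace_op walks the traced graph and retargets every call node that matches one of src_ops. A detail worth noting is that an FX call_method node stores the receiver as its first argument, so x.add(y) has args (x, y) and can become a call_function node for tgt_op with the same args. The sketch below models that rewrite with a hypothetical ToyNode and toy_replace_op, using operator.mul as a stdlib stand-in for torch.mul and a hypothetical fake_torch_add for torch.add; it illustrates the idea under those assumptions and is not Composer's implementation.

```python
import operator
from dataclasses import dataclass, replace
from typing import Any, Callable, List, Tuple


@dataclass(frozen=True)
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node: node kind, target, and argument names."""
    op: str
    target: Any
    args: Tuple[str, ...] = ()


def toy_replace_op(nodes: List[ToyNode], src_ops: Any, tgt_op: Callable[..., Any]) -> List[ToyNode]:
    """Retarget every call node matching src_ops to tgt_op; return the input if nothing matches."""
    src_list = list(src_ops) if isinstance(src_ops, (list, tuple)) else [src_ops]
    rewritten: List[ToyNode] = []
    changed = False
    for node in nodes:
        is_call = node.op in ("call_function", "call_method")
        if is_call and any(node.target == op for op in src_list):
            # For call_method nodes, args already start with the receiver,
            # so they can be reused unchanged for a call_function node.
            rewritten.append(replace(node, op="call_function", target=tgt_op))
            changed = True
        else:
            rewritten.append(node)
    return rewritten if changed else nodes


def fake_torch_add(a, b):
    """Hypothetical placeholder for torch.add."""
    return a + b


nodes = [
    ToyNode("call_function", operator.add, ("x", "y")),    # x + y
    ToyNode("call_function", fake_torch_add, ("x", "y")),  # torch.add(x, y)
    ToyNode("call_method", "add", ("x", "y")),             # x.add(y)
]

new_nodes = toy_replace_op(nodes, [operator.add, fake_torch_add, "add"], operator.mul)
print(sum(node.target is operator.mul for node in new_nodes))  # 3
print(toy_replace_op(nodes, operator.sub, operator.mul) is nodes)  # True
```

Building a new list of frozen nodes keeps the sketch side-effect free; returning the original list object when nothing matches mirrors the documented "returns the input if no instances are found" behavior.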