fx_utils

FX-based model transformation and optimization.

Provides utilities for FX-based model transformations.

Functions

apply_stochastic_residual

Detect and replace residual patterns with their stochastic equivalents.

count_op_instances

Counts the number of instances of ops in gm.

fuse_parallel_linears

If there are parallel linears in the model, fuse them together.

replace_op

Replace one or more operators, torch methods, or functions with another callable.

composer.utils.fx_utils.apply_stochastic_residual(gm, drop_rate=0.2)

Detect and replace residual patterns with their stochastic equivalents.

Parameters

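  • gm (GraphModule) – The source FX-traced graph. This can be the entire model, symbolically traced.

  • drop_rate (float, optional) – The drop rate for the stochastic residual connections. Default: 0.2.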

Returns

GraphModule – Modified GraphModule that has stochastic residual connections.
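The docstring does not spell out what a stochastic residual connection computes. As a rough illustration only, the sketch below applies the classic stochastic depth formulation to a single residual block: during training the residual branch is skipped with probability drop_rate, and at evaluation time it is always applied but scaled by its survival probability 1 - drop_rate. StochasticResidualSketch is a hypothetical name, not part of composer; apply_stochastic_residual rewrites the traced graph rather than wrapping modules, so its exact behavior may differ.

```python
import torch
from torch import nn


class StochasticResidualSketch(nn.Module):
    """Hypothetical illustration (not composer's code) of a stochastic residual block.

    Computes x + branch(x), but in training mode the whole branch is skipped
    with probability drop_rate. In eval mode the branch is always applied and
    scaled by its survival probability (1 - drop_rate).
    """

    def __init__(self, branch: nn.Module, drop_rate: float = 0.2):
        super().__init__()
        self.branch = branch
        self.drop_rate = drop_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Drop the residual branch entirely with probability drop_rate.
            if torch.rand(1).item() < self.drop_rate:
                return x
            return x + self.branch(x)
        # Inference: use the expected contribution of the branch.
        return x + (1.0 - self.drop_rate) * self.branch(x)


# Usage: wrap an ordinary residual branch.
block = StochasticResidualSketch(nn.Linear(8, 8), drop_rate=0.2)
out_train = block.train()(torch.randn(4, 8))
out_eval = block.eval()(torch.randn(4, 8))
```

Scaling at evaluation time keeps the expected output the same in both modes; some implementations instead rescale the kept branch by 1 / (1 - drop_rate) during training.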

composer.utils.fx_utils.count_op_instances(gm, ops)

Counts the number of instances of ops in gm.

Example

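>>> import operator
>>> import torch
>>> from torch.fx import symbolic_trace
>>> from composer.utils.fx_utils import count_op_instances
>>> class M(torch.nn.Module):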
...   def forward(self, x, y):
...     return x + y, torch.add(x, y), x.add(y)
>>> module = M()
>>> traced = symbolic_trace(module)
>>> count_op_instances(traced, torch.add)
1
>>> count_op_instances(traced, [operator.add, torch.add, "add"])
3

Parameters
  • gm (GraphModule) – The source FX-traced graph.

  • ops (Union[Callable, str, List[Union[Callable, str]]]) – The operations to count.

Returns

int – The number of instances of ops in gm.
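To show how such a count can be computed on an FX graph, here is a minimal sketch that walks gm.graph.nodes. It assumes matching rules inferred from the example above: plain callables (operator.add, torch.add) are compared against call_function targets, strings ("add") against call_method names, and nn.Module subclasses (nn.Linear) against the modules behind call_module nodes. count_ops_sketch and _node_matches are hypothetical helpers, not composer's implementation.

```python
import operator
from typing import Callable, List, Union

import torch
from torch import nn
from torch.fx import GraphModule, Node, symbolic_trace

Op = Union[Callable, str]


def _node_matches(gm: GraphModule, node: Node, op: Op) -> bool:
    """Hypothetical matching rule for a single FX node against a single op."""
    # nn.Module subclasses match call_module nodes by the type of the submodule.
    if isinstance(op, type) and issubclass(op, nn.Module):
        return node.op == "call_module" and isinstance(gm.get_submodule(node.target), op)
    # Strings match method calls such as x.add(y).
    if isinstance(op, str):
        return node.op == "call_method" and node.target == op
    # Other callables match function calls such as torch.add(x, y) or x + y.
    return node.op == "call_function" and node.target == op


def count_ops_sketch(gm: GraphModule, ops: Union[Op, List[Op]]) -> int:
    """Hypothetical helper: count nodes in gm.graph that match any op in ops."""
    if not isinstance(ops, list):
        ops = [ops]
    # any() ensures each node is counted at most once, even if several ops match it.
    return sum(1 for node in gm.graph.nodes if any(_node_matches(gm, node, op) for op in ops))


class M(nn.Module):
    def forward(self, x, y):
        return x + y, torch.add(x, y), x.add(y)


traced = symbolic_trace(M())
n_torch_add = count_ops_sketch(traced, torch.add)
n_all_adds = count_ops_sketch(traced, [operator.add, torch.add, "add"])
```

Tracing records x + y as operator.add, torch.add(x, y) as torch.add, and x.add(y) as the method "add", which is why the three spellings have to be listed separately.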

composer.utils.fx_utils.fuse_parallel_linears(gm, keep_weights=False)

If there are parallel linears in the model, fuse them together.

Example

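>>> from torch import nn
>>> from torch.fx import symbolic_trace
>>> from composer.utils.fx_utils import count_op_instances, fuse_parallel_linears
>>> class M(nn.Module):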
...   def __init__(self):
...     super().__init__()
...     self.fc1 = nn.Linear(64, 64)
...     self.fc2 = nn.Linear(64, 64)
...   def forward(self, x):
...     y = self.fc1(x)
...     z = self.fc2(x)
...     return y + z
>>> module = M()
>>> traced = symbolic_trace(module)
>>> count_op_instances(traced, nn.Linear)
2
>>> gm = fuse_parallel_linears(traced)
>>> count_op_instances(gm, nn.Linear)
1

Parameters

gm (GraphModule) – The source FX-traced graph.

Returns

GraphModule – Modified GraphModule with parallel linears fused.
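A natural way to fuse parallel linears rests on a simple identity: two linears that read the same input can be computed as one linear whose weight and bias are the originals stacked along the output dimension, with the result split back into the two branch outputs afterwards. The sketch below checks that identity on plain modules. fuse_two_linears is a hypothetical helper; it does not perform the graph rewriting that fuse_parallel_linears does, and composer's exact fusion strategy may differ.

```python
import torch
from torch import nn


def fuse_two_linears(fc1: nn.Linear, fc2: nn.Linear) -> nn.Linear:
    """Hypothetical helper: one Linear equivalent to two Linears sharing an input."""
    if fc1.in_features != fc2.in_features:
        raise ValueError("parallel linears must have the same in_features")
    if (fc1.bias is None) != (fc2.bias is None):
        raise ValueError("both linears must have a bias, or neither")
    has_bias = fc1.bias is not None
    fused = nn.Linear(fc1.in_features, fc1.out_features + fc2.out_features, bias=has_bias)
    with torch.no_grad():
        # Stack weights along the output dimension: shape (out1 + out2, in_features).
        fused.weight.copy_(torch.cat([fc1.weight, fc2.weight], dim=0))
        if has_bias:
            fused.bias.copy_(torch.cat([fc1.bias, fc2.bias], dim=0))
    return fused


fc1, fc2 = nn.Linear(64, 64), nn.Linear(64, 64)
fused = fuse_two_linears(fc1, fc2)
x = torch.randn(8, 64)
# One larger matmul instead of two; split the result back per branch.
y, z = fused(x).split([fc1.out_features, fc2.out_features], dim=-1)
```

The benefit is that one wider matrix multiply usually uses the hardware better than two narrower ones; the split afterwards is cheap.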

composer.utils.fx_utils.replace_op(gm, src_ops, tgt_op)

Replace one or more operators, torch methods, or functions with another callable.

Example

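>>> import operator
>>> import torch
>>> from torch.fx import symbolic_trace
>>> from composer.utils.fx_utils import count_op_instances, replace_op
>>> class M(torch.nn.Module):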
...   def forward(self, x, y):
...     return x + y, torch.add(x, y), x.add(y)
>>> module = M()
>>> traced = symbolic_trace(module)
>>> traced = replace_op(traced, [operator.add, torch.add, "add"], torch.mul)
>>> count_op_instances(traced, torch.mul)
3

Parameters
  • gm (GraphModule) – The source FX-traced graph.

  • src_ops (Union[Callable, str, List[Union[Callable, str]]]) – The operations to replace.

  • tgt_op (Callable) – Replacement for the operations.

Returns

GraphModule – Modified GraphModule with each instance of an op in src_ops replaced with tgt_op. Returns the input if no instances are found.
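The following is a minimal sketch of this kind of rewrite using the public torch.fx Graph API, assuming the same matching rules as the example above (callables against call_function targets, strings against call_method names). Each matched node is replaced by a call_function node to tgt_op with the same arguments; for a method call such as x.add(y), the receiver x is the first argument, so it becomes torch.mul(x, y). replace_ops_sketch is a hypothetical helper, not composer's implementation.

```python
import operator
from typing import Callable, List, Union

import torch
from torch import nn
from torch.fx import GraphModule, symbolic_trace

Op = Union[Callable, str]


def replace_ops_sketch(gm: GraphModule, src_ops: Union[Op, List[Op]], tgt_op: Callable) -> GraphModule:
    """Hypothetical helper: replace matching function/method calls with tgt_op."""
    if not isinstance(src_ops, list):
        src_ops = [src_ops]
    graph = gm.graph
    for node in list(graph.nodes):  # copy, since the graph is mutated below
        if node.op == "call_function":
            matched = any(not isinstance(op, str) and node.target == op for op in src_ops)
        elif node.op == "call_method":
            matched = any(isinstance(op, str) and node.target == op for op in src_ops)
        else:
            matched = False
        if not matched:
            continue
        # For method calls, args[0] is the receiver (x in x.add(y)),
        # so the new node computes tgt_op(x, y).
        with graph.inserting_after(node):
            new_node = graph.call_function(tgt_op, node.args, dict(node.kwargs))
        node.replace_all_uses_with(new_node)
        graph.erase_node(node)
    # Regenerate gm.forward from the edited graph.
    gm.recompile()
    return gm


class M(nn.Module):
    def forward(self, x, y):
        return x + y, torch.add(x, y), x.add(y)


traced = replace_ops_sketch(symbolic_trace(M()), [operator.add, torch.add, "add"], torch.mul)
out = traced(torch.tensor(2.0), torch.tensor(3.0))  # each output is 2 * 3
```

Calling gm.recompile() after editing the graph is what makes the generated forward reflect the new nodes; without it the GraphModule keeps running the old code.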