# Analyzer
# Overview
The Analyzer is a collection of static graph utilities, including Colossal-AI FX. Its features include:
- ``MetaTensor`` -- enabling:
  - Ahead-of-time profiling
  - Shape propagation
  - Ideal FLOP counting
- ``symbolic_trace()``
  - Robust control-flow tracing / recompile
  - Robust activation checkpoint tracing / CodeGen
  - Easy-to-define bias-addition split
- ``symbolic_profile()``
  - Support for ``MetaTensorMode``, where all tensor operations are executed symbolically
  - Shape inference across devices and a unified ``MetaInfo``
  - Ideal FLOP counter (https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505)
# Quickstart
## Analyzer.FX
**Reference:**
https://pytorch.org/docs/stable/fx.html [[paper](https://arxiv.org/pdf/2112.08429)]
torch.FX is a toolkit for developers to transform ``nn.Module`` instances. FX consists of three main components: a symbolic tracer, an intermediate representation, and Python code generation. ``FX.Tracer`` hacks ``__torch_function__`` and uses a ``Proxy`` object to propagate through the forward function of any ``torch.nn.Module``.
![image](https://user-images.githubusercontent.com/78588128/212531495-bbb934dd-dbbb-4578-8869-6171973f7dd8.png)
Colossal-AI FX is modified from torch.FX, with the extra capability of ahead-of-time profiling enabled by its ``MetaTensor`` subclass.
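For instance, here is a minimal sketch of what symbolic execution under ``MetaTensorMode`` looks like (the import path is an assumption; see ``_subclasses/meta_tensor.py``): tensors carry shape and dtype but no storage, so nothing is allocated on any device.
```python
import torch

from colossalai._analyzer._subclasses import MetaTensorMode  # assumed import path

with MetaTensorMode():
    x = torch.rand(100, 3, 224, 224)   # no real memory is allocated
    w = torch.rand(64, 3, 7, 7)
    y = torch.nn.functional.conv2d(x, w, stride=2, padding=3)

print(y.shape)  # torch.Size([100, 64, 112, 112]) -- shapes propagate symbolically
```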
### Analyzer.FX.symbolic_trace()
A drawback of the original torch.FX implementation is that it handles control flow poorly. Control flow is not a PyTorch-native operand, and deciding which branch to execute requires concrete tensor instances. For example,
```python
import torch
from torch import nn

class MyModule(nn.Module):
    def forward(self, x):
        if x.dim() == 3:
            return x * 2 + 1
        else:
            return x - 5
```
The above function has the computation graph of
![image](https://user-images.githubusercontent.com/78588128/212532631-dba30734-577b-4418-8dc9-004d7983abc5.png)
However, since a ``Proxy`` carries no concrete data, calling ``x.dim()`` on it returns nothing meaningful. In the context of an auto-parallel system, at least the control-flow dependencies on tensor shape should be removed, since any searched strategy can only auto-parallelize a specific computation graph with a fixed tensor shape. It is natural to attach concrete data to a ``Proxy`` and propagate it through control flow.
![image](https://user-images.githubusercontent.com/78588128/212533403-1b620986-1c3a-420a-87c6-d08c9702135d.png)
With ``MetaTensor``, the computation during shape propagation can be virtualized. This speeds up tracing by avoiding allocating actual memory on devices.
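For example, passing ``meta_args`` to the Analyzer's ``symbolic_trace`` (see the Quickstart below) resolves the branch at trace time. A minimal sketch, with an assumed import path:
```python
from colossalai._analyzer.fx import symbolic_trace  # assumed import path

# A 3-D meta input concretizes ``x.dim()``, so the tracer takes the first
# branch without allocating real data.
meta_args = dict(x=torch.rand(3, 224, 224))
gm = symbolic_trace(MyModule(), meta_args=meta_args)
print(gm.code)  # the generated forward contains only the ``x * 2 + 1`` branch
```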
#### Remarks
There is no free lunch for PyTorch in unifying all operands across its own repo and the other repos in its ecosystem. For example, the einops library currently has no intention of supporting torch.FX (see https://github.com/arogozhnikov/einops/issues/188). To support different PyTorch-based libraries without modifying their source code, a good practice is to let users register implementations that substitute for functions torch.FX does not support, or to avoid entering incompatible submodules altogether.
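With vanilla torch.FX, one such escape hatch is ``torch.fx.wrap``, which records a call as a single opaque ``call_function`` node instead of tracing into it. A minimal sketch (the Analyzer may provide its own registration mechanism on top of this):
```python
import torch
import torch.fx

def rearrange_like_op(x):
    # stand-in for a function (e.g. from einops) that torch.FX cannot trace through
    return x.flatten(1)

torch.fx.wrap("rearrange_like_op")  # keep this call as a leaf in the graph

class Wrapper(torch.nn.Module):
    def forward(self, x):
        return rearrange_like_op(x)

gm = torch.fx.symbolic_trace(Wrapper())
print(gm.graph)  # rearrange_like_op appears as a single call_function node
```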
### Analyzer.FX.symbolic_profile()
``symbolic_profile`` is another important feature of Colossal-AI's auto-parallel system. Profiling a DNN can be costly, as it requires allocating memory and executing on real devices. However, for auto-parallel purposes it is enough to detect when and where the intermediate activations (i.e., tensors) are generated, so the whole procedure can be profiled without actually executing it. ``symbolic_profile``, as its name implies, profiles the whole network with symbolic information only.
```python
with MetaTensorMode():
    model = MyModule().cuda()
    sample = torch.rand(100, 3, 224, 224).cuda()
    meta_args = dict(
        x=sample,
    )
    gm = symbolic_trace(model, meta_args=meta_args)
    gm = symbolic_profile(gm, sample)
```
``symbolic_profile`` is enabled by ``ShapeProp`` and ``GraphProfiler``.
#### ShapeProp
Both the Tensor Parallel and Activation Checkpoint solvers need to know the shape information ahead of time. Unlike PyTorch's implementation, this ``ShapeProp`` can be executed under ``MetaTensorMode``. With this, all the preparation for the auto-parallel solvers can be done in milliseconds.
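For reference, PyTorch's stock shape propagation follows the pattern below; the Analyzer's ``ShapeProp`` is assumed to expose a similar interface while additionally accepting meta tensors:
```python
from torch.fx.passes.shape_prop import ShapeProp  # PyTorch's reference implementation

ShapeProp(gm).propagate(sample)
for node in gm.graph.nodes:
    # after propagation, each node carries shape/dtype metadata
    print(node.name, node.meta.get('tensor_meta'))
```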
Meanwhile, it is easy to keep track of the memory usage of each node during shape propagation. However, a drawback of FX is that not every ``call_function`` saves its input for backward, and different tensors flowing through one ``FX.Graph`` can actually share the same layout. This raises problems for fine-grained profiling.
![image](https://user-images.githubusercontent.com/78588128/215312957-7eb6cbc3-61b2-49cf-95a4-6b859149eb8d.png)
To address this problem, I came up with a simulated environment enabled by ``torch.autograd.graph.saved_tensors_hooks`` and a fake ``data_ptr`` (check ``_subclasses/meta_tensor.py`` for more details of the ``data_ptr`` updates).
```python
from typing import Optional

import torch
from torch.autograd.graph import saved_tensors_hooks


class sim_env(saved_tensors_hooks):
    """
    A simulation of memory allocation and deallocation in the forward pass
    using ``saved_tensors_hooks``.

    Attributes:
        ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of a tensor to the tensor itself. This is used
            to track the memory allocation and deallocation.

        param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of all model parameters to the parameter itself.
            This avoids overestimating the memory usage of the intermediate activations.
    """

    def __init__(self, module: Optional[torch.nn.Module] = None):
        super().__init__(self.pack_hook, self.unpack_hook)
        self.ctx = {}
        # guard against ``module=None`` so both dicts stay consistent
        self.param_ctx = {param.data_ptr(): param for param in module.parameters()} if module else {}
        self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {}

    def pack_hook(self, tensor: torch.Tensor):
        # only count tensors that are neither parameters nor buffers
        if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx:
            self.ctx[tensor.data_ptr()] = tensor
        return tensor

    def unpack_hook(self, tensor):
        return tensor
```
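A minimal usage sketch (the model and input are placeholders): run a forward pass under the hooks, then read the saved activations off ``ctx``:
```python
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())  # any nn.Module
env = sim_env(module=model)
with env:
    out = model(torch.rand(4, 3, 224, 224))

saved_bytes = sum(t.numel() * t.element_size() for t in env.ctx.values())
print(f"saved activations: {saved_bytes / 2**20:.2f} MiB")
```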
The ``ctx`` variable keeps track of all saved tensors with a unique identifier. Since ``nn.Parameter`` instances would otherwise also be counted in ``ctx``, which is not desired, ``param_ctx`` keeps track of all parameters in the model; ``buffer_ctx`` likewise keeps track of all buffers. The ``local_ctx`` attached to each ``Node`` marks the memory usage of the stage to which the node belongs. With simple ``intersect``, ``union`` and ``subtract`` operations over these contexts, we can derive any memory-related information (see the sketch after the next code block).

For non-profileable nodes, you can add customized profile rules to simulate the memory allocation. If a ``Graph`` is modified with non-PyTorch functions, such as fused operators, you can register the shape propagation rule with the decorator.
```python
@register_shape_impl(fuse_conv_bn)
def fuse_conv_bn_shape_impl(*args, **kwargs):
    # infer output shape here
    return torch.empty(output_shape, device=output_device)
```
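To illustrate the set algebra mentioned above, here is a hypothetical sketch that assumes each node's ``local_ctx`` maps a ``data_ptr`` to the saved tensor:
```python
from typing import Dict

import torch

def newly_saved_bytes(prev_ctx: Dict[int, torch.Tensor],
                      curr_ctx: Dict[int, torch.Tensor]) -> int:
    """Bytes saved by the current stage but not the previous one (``subtract``)."""
    fresh = curr_ctx.keys() - prev_ctx.keys()
    return sum(curr_ctx[ptr].numel() * curr_ctx[ptr].element_size() for ptr in fresh)

def coexisting_bytes(a_ctx: Dict[int, torch.Tensor],
                     b_ctx: Dict[int, torch.Tensor]) -> int:
    """Bytes alive if both stages' activations coexist (``union``)."""
    merged = {**a_ctx, **b_ctx}
    return sum(t.numel() * t.element_size() for t in merged.values())
```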
Note that ``ShapeProp`` attaches additional information to the graph, which is exactly the input of the ``Profiler``.
#### GraphProfiler
``GraphProfiler`` executes at the node level, and profiles both forward and backward within one node. For example, ``FlopProfiler`` will profile the forward and backward FLOPs of a node, and ``CommunicationProfiler`` will profile the forward and backward communication cost of a node. The ``GraphProfiler`` will attach the profiling results to the ``Node``. These procedures are decoupled for better extensibility.
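Since the attachment keys are not specified here, the following inspection is purely illustrative (``'fwd_flop'`` and ``'bwd_flop'`` are assumed names):
```python
# Hypothetical: the key names are illustrative only, not the Analyzer's actual API.
for node in gm.graph.nodes:
    print(node.name, node.meta.get('fwd_flop'), node.meta.get('bwd_flop'))
```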
To get a general view of the profiled results, you can set ``verbose=True`` to print a summary as well.
```python
import torch
import torchvision.models as tm

model = tm.resnet18()
sample = torch.rand(100, 3, 224, 224)
meta_args = dict(x=sample)
gm = symbolic_trace(model, meta_args=meta_args)
gm = symbolic_profile(gm, sample, verbose=True)
```