
lcskrishna/py-mi


py-mi (A PyTorch Module Instrumentor)

py-mi is a tool built on top of the PyTorch module class (nn.Module). It collects layerwise performance statistics, such as forward and backward timings, for a module created with PyTorch. It does not rely on GPU-specific profiling tools such as nvprof, so it can be used on any GPU or CPU that PyTorch supports.
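Collecting timings without a vendor profiler generally comes down to measuring wall-clock time in Python around the work being profiled. The sketch below illustrates that idea with a hypothetical helper, `time_callable`; it is not py-mi's implementation, and how py-mi hooks into each layer is not described here. In PyTorch, per-layer timing could be attached through `nn.Module.register_forward_pre_hook` and `register_forward_hook`. Because CUDA kernels run asynchronously, a GPU measurement would also need `torch.cuda.synchronize()` before each clock read. The sketch assumes Python 3.3+ for `time.perf_counter`.

```python
import time


def time_callable(fn, iterations, warmup=1):
    """Return the mean wall-clock time of fn() in milliseconds.

    Hypothetical helper for illustration only; not part of pymi.
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    # Warm-up calls are excluded so one-time setup cost does not skew the mean.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    elapsed = time.perf_counter() - start
    # Convert seconds to milliseconds and average over the timed iterations.
    return elapsed * 1000.0 / iterations


if __name__ == "__main__":
    mean_ms = time_callable(lambda: sum(range(10000)), iterations=50)
    print("mean time per call: %.3f ms" % mean_ms)
```

Averaging over many iterations, as py-mi's `iterations` parameter suggests, smooths out run-to-run noise in individual measurements.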

Note: Currently, this tool supports only FP32 computation.


Prerequisites

Make sure the following are installed:

  • Python 2.7 or Python 3.6
  • PyTorch
  • torchvision
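A quick way to confirm these prerequisites from Python is sketched below. The helper `check_prerequisites` is hypothetical (not part of pymi) and assumes Python 3.4+ for `importlib.util.find_spec`; it only checks that each module can be located, not which version is installed.

```python
import importlib.util
import sys


def check_prerequisites(modules=("torch", "torchvision")):
    """Report the Python version and whether each required module is importable."""
    report = {"python": "%d.%d" % sys.version_info[:2]}
    for name in modules:
        # find_spec returns None when a top-level module is not installed.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, value in check_prerequisites().items():
        print("%s: %s" % (key, value))
```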

To install PyTorch and torchvision, follow the links below:

  1. PyTorch: Official PyTorch site or PyTorch on ROCm
  2. torchvision: vision

Alternatively, use the prebuilt PyTorch Docker images available from the respective sites.

Getting Started

To get started, first install the pymi module into your workspace.

To install, clone the repository and run the following command from the repository root:

python setup.py install

This will install a python module named pymi.

How to use this tool in your Python scripts

Import the following in any Python script you want to profile:

import pymi
from pymi import ModuleInstrumentation as mi

To get the layerwise timing profile for a particular network, use the following:

net_layer_data = mi.PyModuleInstrumentation(net, input_size, iterations, is_debug).generate_layerwise_profile_info()

Here, net_layer_data is a map that contains the layer_type, forward_time, backward_time and layer_num fields.
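Once the layerwise data is available, it is often useful to see which layer types dominate the run time. The sketch below assumes, hypothetically, that the data can be viewed as a list of per-layer dicts with the keys `layer_num`, `layer_type`, `forward_time` and `backward_time`; the exact shape of py-mi's map is not documented here, so adapt the access pattern as needed. `summarize_by_layer_type` is an illustrative name, not part of pymi.

```python
from collections import OrderedDict


def summarize_by_layer_type(records):
    """Sum forward and backward times per layer type, in layer order.

    Assumes each record is a dict with 'layer_num', 'layer_type',
    'forward_time' and 'backward_time' keys (an assumed layout,
    not py-mi's documented format).
    """
    totals = OrderedDict()
    for rec in sorted(records, key=lambda r: r["layer_num"]):
        entry = totals.setdefault(
            rec["layer_type"],
            {"forward_time": 0.0, "backward_time": 0.0, "count": 0},
        )
        entry["forward_time"] += rec["forward_time"]
        entry["backward_time"] += rec["backward_time"]
        entry["count"] += 1
    return totals
```

Layer types appear in the order they are first encountered in the network, which keeps the summary readable for sequential models such as AlexNet.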

After generating the layerwise info, use the following to write a summary of the forward and backward timings, together with the run configuration, to a CSV file:

mi.PyModuleInstrumentation(net, input_size, iterations, is_debug).generate_statistics(net_layer_data, <output_prefix str>)
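If you want a CSV with a different layout than `generate_statistics` produces, the standard `csv` module is enough. The sketch below is a hypothetical alternative, not py-mi's output format; it assumes each layer is a dict with `layer_num`, `layer_type`, `forward_time` and `backward_time` keys, and writes one row per layer to any open text file.

```python
import csv

# Column order for the hypothetical per-layer CSV.
FIELDS = ["layer_num", "layer_type", "forward_time", "backward_time"]


def write_layer_csv(records, fileobj):
    """Write a header plus one CSV row per layer, ordered by layer number."""
    # extrasaction="ignore" skips any additional keys a record might carry.
    writer = csv.DictWriter(fileobj, fieldnames=FIELDS, extrasaction="ignore")
    writer.writeheader()
    for rec in sorted(records, key=lambda r: r["layer_num"]):
        writer.writerow(rec)
```

In Python 3, open the output file with `open(path, "w", newline="")` so the `csv` module controls line endings.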

In the commands above:

  • net is a network built with nn.Module, for example torchvision.models.alexnet()
  • input_size is the input shape as a list, for example [1,3,224,224]
  • iterations is the number of iterations you wish to profile.
  • is_debug is a boolean flag (True/False) that enables debug messages.
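For image models such as AlexNet, input_size follows PyTorch's (batch, channels, height, width) layout, so [1,3,224,224] is one 3-channel 224x224 image. Since the tool works in FP32, each element takes 4 bytes; the hypothetical helper below (not part of pymi) computes the resulting input tensor size.

```python
import operator
from functools import reduce

# FP32 (torch.float32) uses 4 bytes per element.
BYTES_PER_FP32 = 4


def fp32_input_bytes(input_size):
    """Return the number of bytes an FP32 tensor with shape input_size occupies."""
    num_elements = reduce(operator.mul, input_size, 1)
    return num_elements * BYTES_PER_FP32


if __name__ == "__main__":
    # 1 * 3 * 224 * 224 = 150528 elements, times 4 bytes = 602112 bytes.
    print(fp32_input_bytes([1, 3, 224, 224]))
```

Doubling the batch dimension doubles the input size, which is a simple way to reason about how larger batches affect memory use.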

To run a small demo, run the following example script:

python generate_layerwise_benchmarks.py --network alexnet