使用pytorchviz和Netron可视化pytorch网络结构

2020/10/16 17:43
阅读数 6.9K

一 使用pytorchviz可视化

 

  • 安装依赖和pytorchviz

pip install graphviz
pip install tochviz (或pip install git+https://github.com/szagoruyko/pytorchviz)

 

Graphviz 是 AT&T 开发的一款开源的图形可视化软件,可以根据dot脚本语言中绘制的无向图(显示了对象间最简单的关系)画出直观的树形图。
Graphviz在Windows中的安装需要下载Release包,并配置环境变量,否则会报错:

graphviz.backend.ExecutableNotFound: failed to execute [‘dot’, ‘-Tpng’, ‘-O’, ‘tmp’], make sure the Graphviz executables are on your systems’ PATH

 

Graphviz下载地址 https://graphviz.gitlab.io/_pages/Download/Download_windows.html

下载之后解压出来是一个“release”文件夹,把“release\bin”目录添加到系统环境变量,之后在终端中输入“dot -V”,显示以下信息表示Graphviz配置成功:

 

  • torchviz可视化torch网络结构

 
  1.  
    # Created by 牧野 CSDN
  2.  
    import torch
  3.  
    from torch import nn
  4.  
    from torchviz import make_dot, make_dot_from_trace
  5.  
     
  6.  
    model = nn.Sequential()
  7.  
    model.add_module('W0', nn.Linear(8, 16))
  8.  
    model.add_module('tanh', nn.Tanh())
  9.  
    model.add_module('W1', nn.Linear(16, 1))
  10.  
     
  11.  
    x = torch.randn(1,8)
  12.  
     
  13.  
    vis_graph = make_dot(model(x), params=dict(model.named_parameters()))
  14.  
    vis_graph.view() # 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开
  15.  
     
  16.  
    with torch.onnx.set_training(model, False):
  17.  
    trace, _ = torch.jit.get_trace_graph(model, args=(x,))
  18.  
    make_dot_from_trace(trace)

 

调用“make_dot”方法创建一个dot对象,使用“view”方法显示出来。

pytorch1.2和1.3版本中使用“torch.jit.get_trace_graph”可能会报错,1.1版本ok。

AttributeError: 'torch._C.Value' object has no attribute 'uniqueName'

 

可视化结果:

 

二 使用Netron可视化

 

Netron开源地址: https://github.com/lutzroeder/Netron
Netron的开发者是Lutz Roeder,一位来自微软Visual Studio团队的帅哥:

 

Netron是一款支持离线查看“各种”神经网络框架的模型可视化神器,其中的“各种”包括:

  1. ONNX (.onnx, .pb, .pbtxt)
  2. Keras (.h5, .keras)
  3. Core ML (.mlmodel)
  4. Caffe (.caffemodel, .prototxt)
  5. Caffe2 (predict_net.pb, predict_net.pbtxt)
  6. MXNet (.model, -symbol.json)
  7. NCNN (.param)
  8. TensorFlow Lite (.tflite)
  9. TorchScript (.pt, .pth)
  10. PyTorch (.pt, .pth)
  11. Torch (.t7)
  12. Arm NN (.armnn)
  13. BigDL (.bigdl, .model)
  14. Chainer, (.npz, .h5)
  15. CNTK (.model, .cntk)
  16. Deeplearning4j (.zip)
  17. Darknet (.cfg)
  18. ML.NET (.zip)
  19. MNN (.mnn)
  20. OpenVINO (.xml)
  21. PaddlePaddle (.zip, __model__)
  22. scikit-learn (.pkl)
  23. TensorFlow.js (model.json, .pb)
  24. TensorFlow (.pb, .meta, .pbtxt)

嗯,够多了。

Netron使用很简单,作者提供了各个平台的安装包,安装之后打开,把保存的模型文件拖入就可以了。
还以上边的模型为例,先把pytorch模型保存出来:

 
  1.  
    import torch
  2.  
    from torch import nn
  3.  
    from torchviz import make_dot, make_dot_from_trace
  4.  
     
  5.  
    model = nn.Sequential()
  6.  
    model.add_module('W0', nn.Linear(8, 16))
  7.  
    model.add_module('tanh', nn.Tanh())
  8.  
    model.add_module('W1', nn.Linear(16, 1))
  9.  
     
  10.  
    torch.save(model, 'model.pth')  # 保存模型

之后用Netron打开保存的“model.pth”:

 

网络结构很清晰,一目了然,右侧还能显示操作的进一步信息。

如果你懒得安装,还可以使用作者提供的在线Netron查看器,地址:https://lutzroeder.github.io/netron/

展开阅读全文
打赏
0
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部