2019/09/08 22:36

# Spatial Transformer Networks

## 网络结构

Spatial Transformer 的结构由三个部分组成：定位网络（localisation network）、网格生成器（grid generator）和采样器（sampler），下面会详细介绍。

## 反向传播

import torch

import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F

from torchsummary import summary
class STN(nn.Module):
    """Spatial transformer that warps its input with a predicted affine map.

    A VGG-16 convolutional backbone, a 7x7 convolution and a two-layer MLP
    regress six affine parameters per sample. The input is then resampled
    through the resulting 2x3 transform, so the output has the same shape
    as the input.

    The 7x7 convolution is expected to reduce the backbone output to 1x1,
    which is the case for 224x224 inputs (backbone output 512x7x7).
    """

    def __init__(self):
        super().__init__()
        # Calling vgg16() with no arguments builds the network without
        # loading pretrained weights.
        self.feature_extractor = vgg16().features
        self.conv = nn.Conv2d(512, 256, 7)
        self.fc = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 6),
        )
        self._init_identity_transform()

    def _init_identity_transform(self):
        """Make the regression head output the identity affine transform."""
        last = self.fc[-1]
        with torch.no_grad():
            # Zero weights + identity bias => theta == [[1, 0, 0], [0, 1, 0]]
            # regardless of the input, so the untrained module is a no-op
            # instead of a random warp.
            last.weight.zero_()
            last.bias.copy_(torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def forward(self, x):
        """Return ``x`` resampled through the predicted affine transform.

        Args:
            x: image batch of shape (b, c, h, w).

        Returns:
            Tensor with the same shape as ``x``.
        """
        features = self.feature_extractor(x)  # (b, 512, 7, 7) for 224x224 input
        # flatten(1) keeps the batch dimension intact; if the conv output is
        # not 1x1 the linear layer fails loudly instead of mixing samples.
        theta = self.conv(features).flatten(1)  # (b, 256)
        theta = self.fc(theta).view(-1, 2, 3)  # (b, 2, 3)
        # Build the sampling grid described by theta, then interpolate x on it.
        # align_corners is passed explicitly (and identically) to both calls
        # so the two steps agree and no default-change warning is raised.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

if __name__ == "__main__":
    # Print a layer-by-layer summary of the model for one 3x224x224 input,
    # evaluated on the CPU.
    input_shape = (3, 224, 224)
    summary(STN(), input_shape, device="cpu")

[Running] python -u "/media/xueaoru/DATA/ubuntu/six/STN.py"
----------------------------------------------------------------
Layer (type)               Output Shape         Param #
================================================================
Conv2d-1         [-1, 64, 224, 224]           1,792
ReLU-2         [-1, 64, 224, 224]               0
Conv2d-3         [-1, 64, 224, 224]          36,928
ReLU-4         [-1, 64, 224, 224]               0
MaxPool2d-5         [-1, 64, 112, 112]               0
Conv2d-6        [-1, 128, 112, 112]          73,856
ReLU-7        [-1, 128, 112, 112]               0
Conv2d-8        [-1, 128, 112, 112]         147,584
ReLU-9        [-1, 128, 112, 112]               0
MaxPool2d-10          [-1, 128, 56, 56]               0
Conv2d-11          [-1, 256, 56, 56]         295,168
ReLU-12          [-1, 256, 56, 56]               0
Conv2d-13          [-1, 256, 56, 56]         590,080
ReLU-14          [-1, 256, 56, 56]               0
Conv2d-15          [-1, 256, 56, 56]         590,080
ReLU-16          [-1, 256, 56, 56]               0
MaxPool2d-17          [-1, 256, 28, 28]               0
Conv2d-18          [-1, 512, 28, 28]       1,180,160
ReLU-19          [-1, 512, 28, 28]               0
Conv2d-20          [-1, 512, 28, 28]       2,359,808
ReLU-21          [-1, 512, 28, 28]               0
Conv2d-22          [-1, 512, 28, 28]       2,359,808
ReLU-23          [-1, 512, 28, 28]               0
MaxPool2d-24          [-1, 512, 14, 14]               0
Conv2d-25          [-1, 512, 14, 14]       2,359,808
ReLU-26          [-1, 512, 14, 14]               0
Conv2d-27          [-1, 512, 14, 14]       2,359,808
ReLU-28          [-1, 512, 14, 14]               0
Conv2d-29          [-1, 512, 14, 14]       2,359,808
ReLU-30          [-1, 512, 14, 14]               0
MaxPool2d-31            [-1, 512, 7, 7]               0
Conv2d-32            [-1, 256, 1, 1]       6,422,784
Linear-33                  [-1, 512]         131,584
ReLU-34                  [-1, 512]               0
Linear-35                    [-1, 6]           3,078
================================================================
Total params: 21,272,134
Trainable params: 21,272,134
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.40
Params size (MB): 81.15
Estimated Total Size (MB): 300.13
----------------------------------------------------------------

[Done] exited with code=0 in 2.511 seconds



0
0 收藏

0 评论
0 收藏
0