einops 张量操作

10/21 09:29
阅读数 253

 

pip install einops

 

from einops import rearrange, reduce, repeat # 按给出的模式重组张量

output_tensor = rearrange(input_tensor, 't b c -> b c t') # 结合重组(rearrange)和reduction操作

output_tensor = reduce(input_tensor, 'b c (h h2) (w w2) -> b h w c', 'mean', h2=2, w2=2) # 沿着某一维复制

output_tensor = repeat(input_tensor, 'h w -> h w c', c=3)

 

重新考虑和上面相同的例子:

y = x.view(x.shape[0], -1) # x: (batch, 256, 19, 19)
y = rearrange(x, 'b c h w -> b (c h w)')
  •  

第二行检查了输入数据拥有四个维度(当然你也可以指定其他数字)
这和仅仅写注释标明数据维度是很不一样的,毕竟据我们所知,注释不能运行也无法阻止错误发生

y = x.view(x.shape[0], -1) # x: (batch, 256, 19, 19)
y = rearrange(x, 'b c h w -> b (c h w)', c=256, h=19, w=19)

 

更多的检查

重新考虑和上面相同的例子:

y = x.view(x.shape[0], -1) # x: (batch, 256, 19, 19)
y = rearrange(x, 'b c h w -> b (c h w)')
  • 1
  • 2

第二行检查了输入数据拥有四个维度(当然你也可以指定其他数字)
这和仅仅写注释标明数据维度是很不一样的,毕竟据我们所知,注释不能运行也无法阻止错误发生

y = x.view(x.shape[0], -1) # x: (batch, 256, 19, 19)
y = rearrange(x, 'b c h w -> b (c h w)', c=256, h=19, w=19)
  • 1
  • 2

对输出的严格定义

下面有两种将张量深度转换为广度(depth-to-space)的方式

# depth-to-space
rearrange(x, 'b c (h h2) (w w2) -> b (c h2 w2) h w', h2=2, w2=2)
rearrange(x, 'b c (h h2) (w w2) -> b (h2 w2 c) h w', h2=2, w2=2)
  • 1
  • 2
  • 3

并且我们至少还有其他的四种方式来进行这种“深度-广度”的转换。哪一种是被框架使用的呢?

这些细节往往会被忽略,因为一般情况下,这些做法不会有什么区别。
但是有时这些细节能有很大的影响(例如使用分组卷积的时候)。
所以你会希望可以在自己代码里讲清楚这个操作。

一致性

reduce(x, 'b c (x dx) -> b c x', 'max', dx=2)
reduce(x, 'b c (x dx) (y dy) -> b c x y', 'max', dx=2, dy=3)
reduce(x, 'b c (x dx) (y dy) (z dz)-> b c x y z', 'max', dx=2, dy=3, dz=4)
  • 1
  • 2
  • 3

上面这些例子展示了无论是几维的张量池化,我们都使用一致的操作,而不会因为张量维度的改变而有不同接口。

广度-深度 或者 深度-广度 的转化在许多框架中都有定义,那 宽度-高度 呢?

rearrange(x, 'b c h (w w2) -> b c (h w2) w', w2=2)
  • 1

与具体框架无关的行为表现

即使是很简单的函数在不同的框架里也往往有不同的写法。

y = x.flatten() # 或者 flatten(x)
  • 1

假设张量x的形状(shape)是(3,4,5),那么y的形状可能是:

  • 在numpy, cupy, chainer, pytorch中: (60,)
  • 在keras, tensorflow.layers, mxnet 和 gluon中: (3, 20)

与框架使用的具体术语无关

举个栗子:tailrepeat常常会令人困扰。当你要沿着宽度复制图片时,你要:

np.tile(image, (1, 2))    # 在numpy中
image.repeat(1, 2)        # pytorch的repeat ≈ numpy的tile
  • 1
  • 2

而使用einops的话,你甚至不需要研究要哪个维度的数据被复制了:

 

展开阅读全文
打赏
0
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
0
分享
OSCHINA
登录后可查看更多优质内容
返回顶部
顶部