PyTorch 广播机制详解

  2022 年 02 月 16 日   PyTorch  

所谓“广播”,就是在两个形状不同的 Tensor 进行逐元素运算的时候,把其中一个(或两个)的形状补齐,使得两个 Tensor 形状相同,就能够运算了。

PyTorch 内置的 Tensor 类型也支持和 Numpy 的 ndarray 类型相同的广播机制。本文主要介绍其基本规则,即什么时候两个 Tensor 可以广播广播的结果是什么以及其他注意事项

一、什么时候两个 Tensor 可以广播?

当以下规则满足时,两个 Tensor 可以广播:

  1. 两个 Tensor 都至少有一个维度(都不是标量);
  2. 倒着同时遍历两个 Tensor 的形状,每个维度的大小
    1. 要么相同,
    2. 要么是 1,
    3. 要么一个维度不存在。

上面的规则也许不太好理解,直接看例子更加直接:

# 两个相同的形状,肯定可以广播
x = torch.empty(5,7,3)
y = torch.empty(5,7,3)

# 一个是标量,一个是矩阵,不可以广播
x = torch.empty((0,))
y = torch.empty(2,2)

# 倒着遍历形状,分别是 1-1(两个相同), 4-1(一个是 1), 3-3(两个相同), 5-_(一个不存在)
# 所以可以广播
x = torch.empty(5,3,4,1)
y = torch.empty(  3,1,1)

# 倒着遍历形状,遍历到 2-3 的时候不满足规则,因此无法广播
x = torch.empty(5,2,4,1)
y = torch.empty(  3,1,1)

二、如果能广播,那么广播后的形状是什么?

从上面的广播应用规则来看,就是让维度更小的 Tensor 扩充到维度更大的 Tensor 的形状上。因此,广播后结果的形状的计算规则为:

  1. 如果两个 Tensor 的维数不同,则先在维数少的 Tensor 的形状前面补 1,直到两个 Tensor 的维数相同;
  2. 然后,对于每一个维度,结果 Tensor 的维度大小就是两个 Tensor 维度大小中较大的那个。

例子:

>>> x = torch.empty(5,1,4,1)
>>> y = torch.empty(  3,1,1)
>>> (x + y).size()
torch.Size([5, 3, 4, 1])

>>> x = torch.empty(1)
>>> y = torch.empty(3,1,7)
>>> (x + y).size()
torch.Size([3, 1, 7])

>>> x = torch.empty(5,2,4,1)
>>> y = torch.empty(3,1,1)
>>> (x + y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

三、其他注意事项

对于原址(in-place)操作,如 add_() 等,PyTorch 不允许进行原址操作的 Tensor 的形状发生改变。如:

# 可以
>>> x = torch.empty(5,3,4,1)
>>> y = torch.empty(3,1,1)
>>> (x.add_(y)).size()

# 不行:x 的形状将发生改变(来迎合 y),不允许!
>>> x = torch.empty(1,3,1)
>>> y = torch.empty(3,1,7)
>>> (x.add_(y)).size()

参考