2024-03-02
高性能计算
Welford 方差计算是一种 在线(Online) 计算方差的方法。一方面,它可以在不存储所有样本的情况下,逐步计算所有样本的方差,更适合处理海量数据;另一方面,它只需要对数据进行 一次(One-pass) 遍历,能减少访存次数,提高计算性能。
三种方差计算方法
在介绍 Welford 方差计算之前,我们先来看看其他两种常见的方差计算方法。
1. Two-pass 方差计算
第一种方法是标准的方差计算公式:
$$ \text{Var} = \frac{1}{N} \sum_{i=1}^{N} (x_i - \bar{x})^2 $$
其中,$N$ 是样本数量,$\bar{x}$ 是样本均值。
为了实现这种算法,我们需要对输入数据进行两次遍历:第一次遍历计算均值 $\bar{x}$,第二次遍历计算方差 $\text{Var}$。以下为 Python 代码实现(本文都使用 Python 代替伪代码来演示计算过程,比较容易理解):
def variance_two_pass(data: np.ndarray):
N = len(data)
# one pass for mean
mean = 0
for x in data:
mean += x
mean /= N
# second pass for variance
sum_of_squares = 0
for x in data:
sum_of_squares += (x - mean) ** 2
return sum_of_squares / N
这种方法的缺陷是需要两次遍历数据,这会导致访存次数变多,计算密度较低,整个计算变成 访存受限(Memory Bound) 的。
2. One-pass 方差计算
之所以需要访问两遍数据,是因为求和公式中的每一项 $(x_i - \bar{x})^2$ 都依赖 $\bar{x}$,而 $\bar{x}$ 必须在完整地访问一次数据后才能获得。
为了解决这个问题,第二种计算方差的方法将 $\bar{x}$ 剥离出来,放到最后,因此只需要对数据进行 一次遍历。它基于以下方差的等价计算公式:
$$ \text{Var} = \frac{1}{N} \sum_{i=1}^{N} x_i^2 - \bar{x}^2 $$
这种方法的 Python 代码实现如下:
def variance_one_pass(data: np.ndarray):
N = len(data)
mean = 0
sum_of_squares = 0
for x in data:
mean += x
sum_of_squares += x ** 2
mean /= N
return sum_of_squares / N - mean ** 2
这种方法看似没有问题,但它在实际计算中可能会出现数值不稳定的情况:如果这批数据的方差很小,但每个数字都较大,此时 sum_of_squares / N
和 mean ** 2
会是 两个非常接近的大数,由于计算机浮点数只能表示稀疏的实数,此时两个数字的差也会非常不精确。在实际计算中,这种数值不稳定甚至可能会导致方差计算结果出现 负数。
3. Welford 方差计算
Welford 方差计算方法也是一种 One-pass 的算法,但它同时也能保证数值计算结果的稳定。
如果将方差计算公式中的分子 $\sum_{i=1}^{N} (x_i - \bar{x})^2$ 称为 Corrected Sum of Squares —— 每个数字关于它们均值的偏移的平方的求和(原论文中说的是 the sum of squares of the deviations of the values about their mean),那么 Welford 方法基于这样的事实:我们可以通过前 $N-1$ 个样本的 Corrected Sum of Squares,来计算前 $N$ 个样本的 Corrected Sum of Squares
观察到两者的差值:
$$ \newcommand{\mean}[2]{\bar{#1}_{#2}} $$
$$ \begin{aligned} &\sum_{i=1}^N (x_i-\mean{x}{N})^2-\sum_{i=1}^{N-1} (x_i-\mean{x}{N-1})^2 \\ &= (x_N-\mean{x}{N})^2 + \sum_{i=1}^{N-1}\left((x_i-\mean{x}{N})^2-(x_i-\mean{x}{N-1})^2\right) \\ &= (x_N-\mean{x}{N})^2 + \sum_{i=1}^{N-1}(x_i-\mean{x}{N} + x_i-\mean{x}{N-1})(\mean{x}{N-1} – \mean{x}{N}) \\ &= (x_N-\mean{x}{N})^2 + (\mean{x}{N} – x_N)(\mean{x}{N-1} – \mean{x}{N}) \\ &= (x_N-\mean{x}{N})(x_N-\mean{x}{N} – \mean{x}{N-1} + \mean{x}{N}) \\ &= (x_N-\mean{x}{N})(x_N – \mean{x}{N-1})\\ &=\Delta \end{aligned} $$
因此,如果知道了前 $N-1$ 个样本的 Corrected Sum of Squares,就能通过再加上 $\Delta$,就能获得前 $N$ 个样本的 Corrected Sum of Squares。最后再除以 $N$ 就能获得前 $N$ 个样本的方差。计算完成!
为了计算 $\Delta=(x_N-\mean{x}{N})(x_N – \mean{x}{N-1})$ 这一项,需要在数据遍历过程中维护 滚动均值(即这里的 $\mean{x}{N}$ 和 $\mean{x}{N-1}$)。这个滚动均值的计算也比较容易:假设前 $N-1$ 个数字的均值为 $\mean{x}{N-1}$,新增一个数字 $x_N$ 时,前 $N$ 个数字的均值就在原均值的基础上再增加 $\frac{x_N-\mean{x}{N-1}}{N}$。证明这个原理比较简单,在这里略去。
理解了这些原理,实现 Welford 方差算法就很简单了:
def variance_Welford(data: np.ndarray):
N = len(data)
mean = 0
s = 0
for i, x in enumerate(data):
old_mean = mean
mean += (x - mean) / (i + 1)
s += (x - mean) * (x - old_mean)
return s / N
实验验证
构造一些随机的数据,使用本文介绍的三种方差计算方法,对比其计算结果(使用 NumPy 的 np.var
作为标准答案)。
输入是一个长度为 100 的随机向量,每个元素采样自标准正态分布:
data = np.random.randn(100)
print('Variance (NumPy):', np.var(data))
print('Variantion (Two-pass):', variance_two_pass(data))
print('Variantion (One-pass):', variance_one_pass(data))
print('Variantion (Welford):', variance_Welford(data))
输出结果:
Variance (NumPy): 0.9109398020682257
Variantion (Two-pass): 0.9109398020682259
Variantion (One-pass): 0.9109398020682251
Variantion (Welford): 0.9109398020682258
此时三种版本的计算结果都还很接近。
但如果让输入是一些数值非常大的数(如 100 万附近),并且其方差很小:
data = 1_000_000_000 + np.random.randn(100)
print('Variance (NumPy):', np.var(data))
print('Variantion (Two-pass):', variance_two_pass(data))
print('Variantion (One-pass):', variance_one_pass(data))
print('Variantion (Welford):', variance_Welford(data))
输出:
Variance (NumPy): 0.8370285141526501
Variantion (Two-pass): 0.8370285141526769
Variantion (One-pass): -640.0
Variantion (Welford): 0.8370285251435098
这里就能看到,普通的 One-pass 方差计算结果非常离谱,已经变成了负数;而 Welford 算法则仍然保持与 NumPy 和 Two-pass 几乎一致。这说明 Welford 在降低访存次数的同时,也做到了数值计算的稳定性。
总结
本文介绍了 Welford 方差计算方法,它是一种在线、一次遍历的方差计算算法,能在不存储所有样本的情况下,逐步计算所有样本的方差。与传统的 Two-pass 和 One-pass 方差计算方法相比,Welford 方法在降低访存次数的同时,也做到了数值计算的稳定性。因此,Welford 方法更适合处理海量数据,也更适合在高性能计算环境中使用。
事实上,Welford 算法启发了 NVIDIA 在 2018 年提出的 Online Softmax 算法,该算法降低了 Softmax 计算的访存次数,提高了计算性能。而 Online Softmax 则直接启发了 FlashAttention,后者已经成为支撑当前最流行的 Transformer 架构的最核心的计算优化手段了。