Welford 方差计算:算法与实现

2024-03-02 高性能计算

Welford 方差计算是一种 在线(Online) 计算方差的方法。一方面,它可以在不存储所有样本的情况下,逐步计算所有样本的方差,更适合处理海量数据;另一方面,它只需要对数据进行 一次(One-pass) 遍历,能减少访存次数,提高计算性能。

三种方差计算方法

在介绍 Welford 方差计算之前,我们先来看看其他两种常见的方差计算方法。

1. Two-pass 方差计算

第一种方法是标准的方差计算公式

$$ \text{Var} = \frac{1}{N} \sum_{i=1}^{N} (x_i - \bar{x})^2 $$

其中,$N$ 是样本数量,$\bar{x}$ 是样本均值。

为了实现这种算法,我们需要对输入数据进行两次遍历:第一次遍历计算均值 $\bar{x}$,第二次遍历计算方差 $\text{Var}$。以下为 Python 代码实现:

def variance_two_pass(data: np.ndarray):
    N = len(data)

    # one pass for mean
    mean = 0
    for x in data:
        mean += x
    mean /= N

    # second pass for variance
    sum_of_squares = 0
    for x in data:
        sum_of_squares += (x - mean) ** 2
    
    return sum_of_squares / N

这种方法的缺陷是需要两次遍历数据,这导致 CPU 或 GPU 访存次数变多,计算性能受限于内存访问(访存相对于计算来说,是显著较慢的)。

2. One-pass 方差计算

第二种计算方差的方法只需要对数据进行 一次遍历,它基于以下方差计算的等价公式:

$$ \text{Var} = \frac{1}{N} \sum_{i=1}^{N} (x_i^2) - \bar{x}^2 $$

这种方法的 Python 代码实现如下:

def variance_one_pass(data: np.ndarray):
    N = len(data)

    mean = 0
    sum_of_squares = 0
    for x in data:
        mean += x
        sum_of_squares += x ** 2
    mean /= N

    return sum_of_squares / N - mean ** 2

这种方法看似没有问题,但它在实际计算中可能会出现数值不稳定的情况:如果这批数据的方差很小,但每个数字都较大,此时 sum_of_squares / Nmean ** 2 会是 两个非常接近的大数,由于计算机浮点数只能表示稀疏的实数,此时两个数字的差也会非常不精确。在实际计算中,这种数值不稳定甚至可能会导致方差计算结果出现 负数

3. Welford 方差计算

Welford 方差计算方法也是一种 One-pass 的算法,但它同时也能保证数值计算结果的稳定。

如果将方差计算公式中的分子 $\sum_{i=1}^{N} (x_i - \bar{x})^2$ 称为 Corrected Sum of Squares —— 每个数字关于它们均值的偏移的平方的求和(原论文中说的是 the sum of squares of the deviations of the values about their mean),那么 Welford 方法基于这样的事实:我们可以通过前 $N-1$ 个样本的 Corrected Sum of Squares,来计算前 $N$ 个样本的 Corrected Sum of Squares

观察到两者的差值:

$$ \newcommand{\mean}[2]{\bar{#1}_{#2}} $$

$$ \begin{aligned} &\sum_{i=1}^N (x_i-\mean{x}{N})^2-\sum_{i=1}^{N-1} (x_i-\mean{x}{N-1})^2 \\ &= (x_N-\mean{x}{N})^2 + \sum_{i=1}^{N-1}\left((x_i-\mean{x}{N})^2-(x_i-\mean{x}{N-1})^2\right) \\ &= (x_N-\mean{x}{N})^2 + \sum_{i=1}^{N-1}(x_i-\mean{x}{N} + x_i-\mean{x}{N-1})(\mean{x}{N-1} – \mean{x}{N}) \\ &= (x_N-\mean{x}{N})^2 + (\mean{x}{N} – x_N)(\mean{x}{N-1} – \mean{x}{N}) \\ &= (x_N-\mean{x}{N})(x_N-\mean{x}{N} – \mean{x}{N-1} + \mean{x}{N}) \\ &= (x_N-\mean{x}{N})(x_N – \mean{x}{N-1}) \end{aligned} $$

因此,如果知道了前 $N-1$ 个样本的 Corrected Sum of Squares,就能通过再加上 $(x_N - \mean{x}{N}) (x_N - x_{N+1})$ ,就能获得前 $N$ 个样本的 Corrected Sum of Squares。最后再除以 $N$ 就能获得前 $N$ 个样本的方差。计算完成!

为了计算 $(x_N-\mean{x}{N})(x_N – \mean{x}{N-1})$ 这一项,需要在数据遍历过程中维护一个滚动均值(即 $\mean{x}{N}$ 和 $\mean{x}{N-1}$)。这个滚动均值的计算也比较容易:假设前 $N-1$ 个数字的均值为 $\mean{x}{N-1}$,新增一个数字 $x_N$ 时,前 $N$ 个数字的均值就在原均值的基础上再增加 $\frac{x_N-\mean{x}{N-1}}{N}$。证明这个原理比较简单,在这里略去。

理解了这些原理,实现 Welford 方差算法就很简单了:

def variance_Welford(data: np.ndarray):
    N = len(data)

    mean = 0
    s = 0
    for i, x in enumerate(data):
        old_mean = mean
        mean += (x - mean) / (i + 1)
        s += (x - mean) * (x - old_mean)

    return s / N

实验验证

构造一些随机的数据,使用本文介绍的三种方差计算方法,对比其计算结果(使用 NumPy 的 np.var 作为标准答案)。

输入是一个长度为 100 的随机向量,每个元素采样自标准正态分布:

data = np.random.randn(100)
print('Variance (NumPy):', np.var(data))
print('Variantion (Two-pass):', variance_two_pass(data))
print('Variantion (One-pass):', variance_one_pass(data))
print('Variantion (Welford):', variance_Welford(data))

输出结果:

Variance (NumPy): 0.9109398020682257
Variantion (Two-pass): 0.9109398020682259
Variantion (One-pass): 0.9109398020682251
Variantion (Welford): 0.9109398020682258

此时三种版本的计算结果都还很接近。

但如果让输入是一些数值非常大的数(如 100 万附近),并且其方差很小:

data = 1_000_000_000 + np.random.randn(100)
print('Variance (NumPy):', np.var(data))
print('Variantion (Two-pass):', variance_two_pass(data))
print('Variantion (One-pass):', variance_one_pass(data))
print('Variantion (Welford):', variance_Welford(data))

输出:

Variance (NumPy): 0.8370285141526501
Variantion (Two-pass): 0.8370285141526769
Variantion (One-pass): -640.0
Variantion (Welford): 0.8370285251435098

这里就能看到,普通的 One-pass 方差计算结果非常离谱,已经变成了负数;而 Welford 算法则仍然保持与 NumPy 和 Two-pass 几乎一致。这说明 Welford 在降低访存次数的同时,也做到了数值计算的稳定性。

总结

本文介绍了 Welford 方差计算方法,它是一种在线、一次遍历的方差计算算法,能在不存储所有样本的情况下,逐步计算所有样本的方差。与传统的 Two-pass 和 One-pass 方差计算方法相比,Welford 方法在降低访存次数的同时,也做到了数值计算的稳定性。因此,Welford 方法更适合处理海量数据,也更适合在高性能计算环境中使用。

事实上,Welford 算法启发了 NVIDIA 在 2018 年提出的 Online Softmax 算法,该算法降低了 Softmax 计算的访存次数,提高了计算性能。

参考