Debug 故事 01

警告:为方便阅读,本文部分代码实为伪代码,请勿当真。

背景

最近我用 espnet 跑了一个比较慢的基线实验,采用 SNR loss 进行训练,虽然用上了单机 4 卡(2080ti)依然需要一周左右的时间才能完成。在基线的基础上,我希望调整一下模型训练过程,以便进行后续拓展:

  • 在计算 SNR loss 前对当前 batch 的标签进行判断,如果存在全零项,就用 L1 loss 函数来代替,因为 SNR loss 在标签为全零时无法计算。

于是,我对原先的 loss 计算代码

loss, stats = 0.0, {}
...
loss_ = snr_loss(speech_ref, speech_pre)
stats_ = {"snr_loss": loss_}
loss += loss_
stats.update(stats_)

做了如下修改:

loss, stats = 0.0, {}
...
has_all_zero = torch.cat(speech_ref, dim=0).sum(dim=-1).eq(0).any()
if has_all_zero:
    loss_ = l1_loss(speech_ref, speech_pre)
    stats_ = {"l1_loss": loss_}
else:
    loss_ = snr_loss(speech_ref, speech_pre)
    stats_ = {"snr_loss": loss_}
loss += loss_
stats.update(stats_)

Bug?

上面的修改看似应该没有任何问题,于是我开始了模型训练。然而在检查模型的训练日志时,我发现了一个令人震惊的事实,那就是日志显示模型的 l1_loss 始终都是负的!这在理论上是绝不可能的,因为 L1 loss 是绝对值。

Debug

考虑到日志显示的结果是对若干步训练结果的平均(1 个 iteration 大概包括 100 个 batch),于是我首先在上面修改的代码中插入了 print 语句来检验每个具体 batch 计算出的 L1 loss。结果也如预期般的正常,每一个打印出来的 L1 loss 值都是正的。这就与日志中的结果相矛盾了。

无端联想

此时我心中有一个巨大的疑惑:后续的代码中难道有任何操作会修改 stats 中的 loss 值吗?

乍一看似乎找不到这种操作,我只好从写日志的函数着手开始分析。因为最终写到训练日志中的 l1_loss 是不正常的,如果写日志的函数的输入输出值对应不上的话,很有可能就是它的问题!

果然

经过一番跳转可以找到,espnet 中模型训练的日志是通过类别 Reporter 的对象来输出的,它的 log_message 方法就是用来生成写到文件的日志信息。于是我在其中插入了一些 print 语句来探查输入的原始数据的值,得到的是如下的一串数字的列表:

l1_loss=[WeightedAverage(value=nan, weight=0), WeightedAverage(value=-0.835070
    788860321,weight=8), WeightedAverage(value=0.017597751691937447, weight=8),
    WeightedAverage(value=nan, weight=0), WeightedAverage(value=nan, weight=0),
    WeightedAverage(value=nan, weight=0), WeightedAverage(value=nan, weight=0),
    WeightedAverage(value=-1.290090799331665, weight=8), WeightedAverage(value=
    nan, weight=0), WeightedAverage(value=nan, weight=0), WeightedAverage(value=
    nan, weight=0), WeightedAverage(value=nan, weight=0), WeightedAverage(value=
    nan, weight=0), WeightedAverage(value=nan, weight=0), WeightedAverage(value=
    nan, weight=0), WeightedAverage(value=nan, weight=0), WeightedAverage(value=
    nan, weight=0), WeightedAverage(value=nan, weight=0), WeightedAverage(value=
    nan, weight=0), WeightedAverage(value=-0.5981062650680542, weight=8),
    ...
    WeightedAverage(value=0.014390362426638603, weight=8)]

列表里的每个 WeightedAverage 对象的 value 就对应着一个 batch 内算出的 loss 平均值,weight 则代表该 batch 的 loss 权重(实际上等于 batch_size)。

最终日志输出的 l1_loss 标量值就是这些列表元素按各自权重的加权和,如果遇到 loss 值是 nan 则跳过。

列表里的那些权重为 0 的 nan 值实际上是代码添加的,目的是将 stats 字典的所有 key 对应的 value(都是列表)长度补齐到等长,即保证每个 batch 都有与之对应的一项数值。

从上面的例子可以看出,日志系统在拿到“原始数据”的时候看到的已经是有正有负的 L1 loss 值了,因此 bug 应该在这一步之前就已经出现。

What?

那么,究竟是哪一步操作导致 stats 中的 l1_loss 如此奇怪呢?

(其实原因早已经在前文中提到了……在继续往下翻之前,你不妨大胆猜想一下)

优雅的分割线

检查完了 loss 计算和写日志的代码,夹在它们中间的就只剩一些零碎的操作了,比如反向传播、优化器迭代、学习率调整等等。这些显然和 stats 的修改没有关系。

难道遗漏了什么?为了不错过任何一点蛛丝马迹,我打算从 loss 计算结束开始,逐行进行调查。很快啊,一个以 stats 作为输入和输出的函数引起了我的注意。

if ngpu > 1 or distributed:
    # Apply weighted averaging for loss and stats
    loss = (loss * weight.type(loss.dtype)).sum()

    # if distributed, this method can also apply all_reduce()
    stats, weight = recursive_average(stats, weight, distributed)

    # Now weight is summation over all workers
    loss /= weight

这个 recursive_average 函数的特别之处在于,它的最后一个参数 distributed 意味着它和多机/多卡的运行环境有着紧密关系。

进一步跳转到它的定义

def recursive_average(obj, weight: torch.Tensor, distributed: bool = False):
    obj = recursive_sum(obj, weight, distributed)
    weight = weight.sum()
    if distributed:
        torch.distributed.all_reduce(weight, op=ReduceOp.SUM)
    # Normalize weight to be sum-to-1
    obj = recursive_divide(obj, weight)
    return obj, weight

就可以看到它的功能是对 stats 中所有的 value 列表的每一个元素都进行某种平均。可是对一个标量元素取平均有什么意义?

==

这个函数针对的场景如果是多机/多卡的话……

看着 recursive_average, recursive_sum, recursive_divide 函数中都出现的 torch.distributed.all_reduce 函数,我突然意识到,这些操作其实是对不同 GPU 上的对应数据(loss 值)进行聚合,那么平均就具有意义了。

同时,看到 all_reduce 这个熟悉的名词,我不禁回想起三年前《并行计算与并行算法》这门课上的学到的一些概念……我开始有了一些想法。

为了验证这个想法,我首先把实验设置换成了单卡,重新开始一次训练。几分钟后,我打开日志文件瞧了瞧,果然所有的 l1_loss 都恢复成了预期的正值。

果然

有了实验的验证,我开始按照新的想法修改起最初的代码:

loss, stats = 0.0, {}
...
has_all_zero = torch.cat(speech_ref, dim=0).sum(dim=-1).eq(0).any()
if has_all_zero:
    loss_l1 = l1_loss(speech_ref, speech_pre)
    loss_snr = loss_l1.new_zeros(())
    stats_ = {"l1_loss": loss_l1, "snr_loss": loss_snr}
else:
    loss_snr = snr_loss(speech_ref, speech_pre)
    loss_l1 = loss_snr.new_zeros(())
    stats_ = {"l1_loss": loss_l1, "snr_loss": loss_snr}
loss += loss_snr + loss_l1
stats.update(stats_)

再将实验设置改回 4 卡,重新进行训练。随着第一条 loss 日志的出现,我打消了心中的疑虑,同时也暗暗庆幸我当初决定采用 L1 loss 作为备选,而不是类似 SNR 的没有数值边界的 loss,不然这个 bug 恐怕要以更加诡异隐秘的形式出现了。

揭开面纱

相信熟悉并行编程的读者此时已经猜到了这个 bug 的成因。

没错,就是 AllReduce 的通信!(对 AllReduce 不熟悉的读者可以阅读 MPI 相关部分的教程:点击此处

MPI_AllReduce

当一个 GPU 上的进程运行到 torch.distributed.all_reduce 这个语句的时候,就会向所有其他 GPU 上的进程广播它所拥有的一份数据(用于聚合),并且该进程会进入阻塞状态,直到接收到来自所有其他 GPU 上进程的一份数据,完成 reduce 的操作,才会继续运行后面的语句。

而在我修改的初版代码中隐藏的问题是,不同 GPU 上处理完一个 batch 之后得到的 stats 字典中的每个键下面的值列表长度可能是不一致的,比如

GPU 0:
    stats = {
        "l1_loss": [tensor(0.057)],
        "snr_loss": [tensor(-1.052), tensor(-3.037), tensor(-2.041)],
    }

GPU 1:
    stats = {
        "l1_loss": [tensor(0.044), tensor(0.023)],
        "snr_loss": [tensor(-2.092), tensor(-1.899)],
    }

由于不同 GPU 上所有值列表长度之和是相等的(等于 batch 数量),在运行到 recursive_average 对不同 GPU 上 stats 进行聚合时能正常完成,不会出现数据数量不一致的错误。

但是,这并不意味着聚合的过程是正确的,因为如前面所提到的,只要运行到 torch.distributed.all_reduce 这个语句,当前 GPU 上的进程就会阻塞并开始接收来自其他 GPU 上的进程的数据,但它并不关心接收到的数据是不是属于 stats 字典中同一个键的。

为了说清楚这一点,我们可以看一下 recursive_sum 这个函数的具体定义,因为在 recursive_average 中首先就调用了该函数来聚合 stats 字典中的数据:

def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False):
    assert weight.dim() == 1, weight.size()
    if isinstance(obj, (tuple, list)):
        return type(obj)(recursive_sum(v, weight, distributed) for v in obj)
    elif isinstance(obj, dict):
        return {k: recursive_sum(v, weight, distributed) for k, v in obj.items()}
    elif isinstance(obj, torch.Tensor):
        assert obj.size() == weight.size(), (obj.size(), weight.size())
        obj = (obj * weight.type(obj.dtype)).sum()
        if distributed:
            torch.distributed.all_reduce(obj, op=ReduceOp.SUM)
        return obj
    elif obj is None:
        return None
    else:
        raise ValueError(type(obj))

虽然 recursive_sum 有很多 if 分支,但可以发现,不管输入数据是哪种合法的数据类型,最终都会递归地跳转到 isinstance(obj, torch.Tensor) 的分支,并运行 torch.distributed.all_reduce 语句进行数据聚合。

以 GPU 0 上的 stats 字典的 recursive_sum 处理过程为例,它的迭代过程如下:

"""
stats = {
    "l1_loss": [tensor(0.057)],
    "snr_loss": [tensor(-1.052), tensor(-3.037), tensor(-2.041)],
}
"""

(1)调用 recursive_sum(stats, weight, distributed=true)
↳ 进入 isinstance(obj, dict) 分支,依次执行 (2) 和 (4)

    # "l1_loss"
    (2)调用 recursive_sum([tensor(0.057)], weight, distributed=true)
    ↳ 进入 isinstance(obj, (tuple, list)) 分支,执行 (3)

        (3)调用 recursive_sum(tensor(0.057), weight, distributed=true)
        ↳ 进入 isinstance(obj, torch.Tensor) 分支
        ↳ 执行 torch.distributed.all_reduce 语句

    # "snr_loss"
    (4)调用 recursive_sum([tensor(-1.052), tensor(-3.037), tensor(-2.041)], weight, distributed=true)
    ↳ 进入 isinstance(obj, (tuple, list)) 分支,依次执行 (5), (6) 和 (7)

        (5)调用 recursive_sum(tensor(-1.052), weight, distributed=true)
        ↳ 进入 isinstance(obj, torch.Tensor) 分支
        ↳ 执行 torch.distributed.all_reduce 语句

        (6)调用 recursive_sum(tensor(-3.037), weight, distributed=true)
        ↳ 进入 isinstance(obj, torch.Tensor) 分支
        ↳ 执行 torch.distributed.all_reduce 语句

        (7)调用 recursive_sum(tensor(-2.041), weight, distributed=true)
        ↳ 进入 isinstance(obj, torch.Tensor) 分支
        ↳ 执行 torch.distributed.all_reduce 语句

同理,GPU 1 上的 stats 字典的 recursive_sum 迭代过程如下:

"""
stats = {
    "l1_loss": [tensor(0.044), tensor(0.023)],
    "snr_loss": [tensor(-2.092), tensor(-1.899)],
}
"""

(1)调用 recursive_sum(stats, weight, distributed=true)
↳ 进入 isinstance(obj, dict) 分支,依次执行 (2) 和 (5)

    # "l1_loss"
    (2)调用 recursive_sum([tensor(0.044), tensor(0.023)], weight, distributed=true)
    ↳ 进入 isinstance(obj, (tuple, list)) 分支,依次执行 (3) 和 (4)

        (3)调用 recursive_sum(tensor(0.044), weight, distributed=true)
        ↳ 进入 isinstance(obj, torch.Tensor) 分支
        ↳ 执行 torch.distributed.all_reduce 语句

        (4)调用 recursive_sum(tensor(0.023), weight, distributed=true)
        ↳ 进入 isinstance(obj, torch.Tensor) 分支
        ↳ 执行 torch.distributed.all_reduce 语句

    # "snr_loss"
    (5)调用 recursive_sum([tensor(-2.092), tensor(-1.899)], weight, distributed=true)
    ↳ 进入 isinstance(obj, (tuple, list)) 分支,依次执行 (6) 和 (7)

        (6)调用 recursive_sum(tensor(-2.092), weight, distributed=true)
        ↳ 进入 isinstance(obj, torch.Tensor) 分支
        ↳ 执行 torch.distributed.all_reduce 语句

        (7)调用 recursive_sum(tensor(-1.899), weight, distributed=true)
        ↳ 进入 isinstance(obj, torch.Tensor) 分支
        ↳ 执行 torch.distributed.all_reduce 语句

对比两个 GPU 上的 recursive_sum 迭代过程,我们就很容易发现:

  • GPU 0 在第二次执行 torch.distributed.all_reduce 语句的时候对应的数据是 "snr_loss" 的值列表中第一个元素
  • GPU 1 在第二次执行 torch.distributed.all_reduce 语句的时候对应的数据是 "l1_loss" 的值列表中第二个元素

这种将属于不同 loss 的数据聚合在一起的行为显然是不合理的(“键值错位”),并且这也能解释为什么用最初修改的代码训练时,日志中显示所有 l1_loss 都是负值。
因为上述不合理的聚合过程在训练中很容易出现,并且 L1 loss 通常都是绝对值比较小的正数,而 SNR loss 则通常是绝对值比较大的负数,两者聚合在一起就很容易导致最终结果也是负数。

而在后来修改的代码中,我通过确保每个 batch 的 stats 字典中都同时保存 l1_loss 和 snr_loss 两种数据,避免了 AllReduce 的数据聚合过程出现“键值错位”的问题,也就不会遇到开头描述的那种匪夷所思的现象了。