[pytorch]高阶导数极其缓慢,呈指数增长

2023-12-07 376 views
2

🐛 描述错误

在我的应用程序中,我需要获取函数的 n 阶混合导数。然而,我发现 torch.autograd.grad 计算时间随着 n 的增加呈指数增加。这是预期的吗?有什么办法可以解决吗?

这是我区分函数 self.F 的代码(从 R^n -> R^1):

def differentiate(self, x):
    x.requires_grad_(True)
    xi = [x[...,i] for i in range(x.shape[-1])]
    dyi = self.F(torch.stack(xi, dim=-1))
    for i in range(self.dim):
        start_time = time.time()
        dyi = torch.autograd.grad(dyi.sum(), xi[i], retain_graph=True, create_graph=True)[0]
        grad_time = time.time() - start_time
        print(grad_time)
    return dyi

这些是上述循环每次迭代打印的时间:

0.0037012100219726562
0.005133152008056641
0.008165121078491211
0.019922733306884766
0.059255123138427734
0.1910409927368164
0.6340939998626709
2.1612229347229004
11.042078971862793

我认为这是因为计算图的大小正在增加?有没有办法解决?我认为我可以通过使用 torch.func.grad 采取函数方法(大概不需要计算图)来规避这个问题。然而,这实际上增加了相同代码的运行时间!我没有正确理解 torch.func.grad 吗?

版本

torch 2.1.0

回答

8

据我了解,您必须计算和存储的梯度数量呈指数增长,公式如下num_params ** order

9

现在就结束了,正如预期的那样——当你进行更高阶的梯度时,你将需要通过越来越大的图形进行反向传播。

3

为什么会这样?它一定比仅仅更复杂num_params ** order。例如,如果您有一个带有多项式的多项式n,那么您可以轻松计算k线性时间内的 - 阶导数。我可以看到像这样的函数sin(x^2)如何爆炸,因为一阶导数是2x*cos(x^2),第二阶是-4x*sin(x^2)+2*cos(x^2),所以你最终会为你所采用的每个导数得到更多项,作为导致指数时间的分支因子(而且,我假设火炬确实不要组合相似的项来缩小计算图)。在大多数人的情况下,像这样的非线性函数只在激活函数中发挥作用。如果定义一个层f(w*x+b),其中 f 是激活,则导数 d/dx 为w*f'(w*x+b)w不依赖于x,因此推测项数只能根据您在激活函数中的选择而增长。这是否意味着激活函数的选择对于高阶导数的速度至关重要,并且激活函数的某些选择可以显着减少运行时间?我希望relu速度会很快,因为它不会增加术语数量。

4

我可能误解了你的代码。我看到你的成绩是根据上一级成绩的总和得出的,所以你不会遇到天真的指数梯度问题。

你可能是对的,我认为毕业生最多比前锋拥有更多的操作员,大约 3-4 个。从经验来看,每个订单的连续执行时间似乎会增加 3-4 倍。

您可以使用更简单的前向函数(例如 relu)来共享时间吗?

6

当然,我刚刚用 ReLU 运行了它:

0.01187586784362793
0.00047016143798828125
0.00043773651123046875
0.0005078315734863281
0.00047707557678222656
0.00044083595275878906
0.00038886070251464844
0.00038313865661621094
0.0003960132598876953

(我之前用的是Mish)

4

有趣的是,它似乎保持大致不变。我想说问题解决了吗?

但是,如果有人能够将文献中的结果链接到关于恒定因子开销的证明,那将非常感激(一直试图在网络上搜索它但无济于事)。

我想这个常数毕竟取决于函数。如果运算的导数没有封闭的解析形式,则常数可能是无限的。

0

这很有趣,但不幸的是 Mish/ReLU 仅用于测试。在我的应用程序中,我需要使用相当复杂的激活函数(erf(x) 的 n 次积分)。

这更具推测性,但如果我们可以提供 n 阶导数的解析解,是否可以显着减少计算时间?我想知道这是否能让我们绕过“分支因子”,即火炬试图获取每一项的导数。我意识到实现起来会相当棘手,因为 torch 目前是按顺序求导的,所以我主要是问这在理论上是否可行。

9

哦,这很聪明!因此,如果您为 n 以内的每个导数实现backward(),将结果提供为另一个自定义模块的函数,那么可以消除分支因子吗?

2

也许 - 这也取决于你的向后实现的效率(可能没有定制的 CUDA 内核,尽管 PT2 Triton 可能会在这里帮助你 - 你可以尝试torch.compile手动编码/分析派生的向后函数:thinking:)

如果可以加快速度,您可能还想尝试近似。

但这很难说,因为最终 GPU 只执行原始数学运算,并且通常以很高的成本来模拟其余部分。

6

向后的实现不是问题,我已经有了所有导数的分析形式。我只是想确认这样的逻辑:这将解决导致毕业时间呈指数增长的“分支问题”。

感谢您的帮助!

1

酷,让我们知道进展如何。很想知道您是否可以实现所需的加速。

2

好的,结果出来了。这些是我的自定义激活函数没有任何优化的时间:

0.01146697998046875
0.005676984786987305
0.016211986541748047
0.04466104507446289
0.14060378074645996
0.4138617515563965
1.1940059661865234
3.6735780239105225
11.722836256027222

这些是我实现可以递归调用的自定义向后传递时的结果:

0.01222991943359375
0.00305938720703125
0.005202054977416992
0.009228229522705078
0.01624774932861328
0.03122401237487793
0.10631084442138672
0.14510798454284668
0.35104799270629883

显着的加速,但对于极高阶导数仍然不可行!我意识到还有一个单独的问题。虽然这解决了由于激活导数中的多个项而导致每次向后传递的“分支因子”,但它无法解决链式法则引入的问题:如果有多个层,则应用链式法则,因此您得到多项取决于 x。在下一遍中,再次应用链式法则,并将乘积法则应用于您当前拥有的两项。我只看到两种解决方法:1)在网络中仅使用一种激活,或者 2)使用 ReLU,它不需要应用链式法则,因为它是线性的。不幸的是,在我的应用程序中,我必须使用这一特定的激活函数,因此,如果我不希望高阶导数的计算时间呈指数增长,我必须使用仅具有一个激活函数的两层 MLP。以下是相同网络的 2 层而不是 4 层的结果(我确认 3 层也会呈指数爆炸):

0.010089874267578125
0.0007302761077880859
0.0004699230194091797
0.00039124488830566406
0.00028896331787109375
0.00026702880859375
0.0002627372741699219
0.0001480579376220703
0.0001418590545654297
2

在下一遍中,再次应用链式法则,并将乘积法则应用于您当前拥有的两项

是的,基本上,由于链式法则,梯度本身会创建一个分支因子。我不知道如何克服这个问题。

9

嗯,我不明白为什么会这样。如果我们使用反向传播算法来计算高阶导数相对于函数参数的梯度:

$\eta 0(x) = F {W}(x)$

$\eta_N(x) = \sum_W \nabla W \eta {N-1}(x) = \sum k \sum {W k} \frac{\partial \eta {N-1}(x)}{\部分 a n} \cdot (\prod {k < j \le n} \frac{\partial a_j}{\partial o_j} \cdot \frac{\partial o j}{\partial a {j-1}} ) \cdot \frac{\partial a_k}{\partial W_k}$

$= \sum k \sum {W k} \frac{\partial \eta {k-1}(x)}{\partial a n} \cdot (\prod {k < j \le n} \frac{\部分 a_j}{\部分 o_j} \cdot W_j ) \cdot (\frac{\部分 a_k}{\部分 o k} \otimes a {k-1})$

$= \sum k \sum {W_k} (\bar a_k \odot \frac{\partial a_k}{\partial o k}) \otimes a {k-1}$

在哪里:

  • $a_0 = x$
  • $o_j = W j \cdot a {j-1}$
  • $n$ 是层数
  • $a_j$ 是给定 $o_j$ 的激活函数的输出。

伴随式$\bar a_k$,通过从标量输出到输入的链式法则导出,递归定义为

$\bar a_k = \frac{\partial \eta(x)}{\partial a k} = \frac{\partial \eta(x)}{\partial a {k+1}} \cdot \frac{\部分 a {k+1}}{\部分 o {k+1}} \cdot \frac{\partial o_{k+1}}{\部分 a k}$ $= (\bar a {k+1} \odot \frac{\partial a {k+1}}{\partial o {k+1}}) \cdot W_k$

如果您直接将梯度应用于产品,那么是的,您会遇到指数梯度问题。然而,反向传播应该能够避免这种情况。

8

哦,我明白了,是因为你使用了激活函数的组合吗?在这种情况下,是的,你会遇到这个问题。解决这个问题的具体方法是为每个高阶导数都有一个自定义的梯度函数。

即$\frac{\partial a_k}{\partial o_k}[N] = A^{(N)}(o_k) = \frac{\partial A^{(N-1)}}{\partial o_k}$对于每个衍生品订单 $N$

5

我确实为每个高阶导数都有一个定制的梯度。如果我只使用自定义梯度作为一阶导数,那么性能会更差。

9

嗯,基本上,我不明白为什么你会呈指数级增长。

我意识到还有一个单独的问题。虽然这解决了由于激活导数中的多个项而导致每次向后传递的“分支因子”,但它无法解决链式法则引入的问题:如果有多个层,则应用链式法则,因此您得到多项取决于 x。a 我不明白你的意思。从上面可以看出,由于反向传播,我们没有任何链式法则的东西。

我怀疑真正的原因是您的自定义向后实际上是通过 autograd 导出的,从而导致按照原始问题的分支规则。

0

具有 2 个激活函数的 MLP 三阶导数

这证明了我正在谈论的内容。无论你的激活 f 是什么(假设它不是分段线性的,在这种情况下你可以忽略链式法则),由于链式法则和乘积法则的组合,你会得到大量的项。我现在对 f 的每个导数都有一个有效的实现,但这并没有改变调用这些导数的术语数量呈爆炸式增长的事实。

6

不,我认为这是错误的——这是由于符号分化造成的。反向传播不会遇到这个问题。

6

如果是这样的话,那么为什么我们要凭经验来看待这个问题呢?

在反向传播中,我们将给定模块中的导数乘以上游导数。因此,在第一遍结束时,我们将得到每个模块的导数的乘积。在下一次迭代中(对于高阶导数),我们是否需要对所有这些相乘项求导,并根据乘积规则创建更多项?

1

不会。正如您从我上面写下的方程式中看到的,反向传播使用称为伴随方法的方法来缓存中间结果。这意味着只有激活函数才能在简单的 MLP 中创建更多算子。

不过,从技术上讲,您应该看到您的计算图随着导数的数量线性增长,因为您必须向前运行,然后每个向后链接一个接一个......我想知道为什么我们没有看到这种线性增长ReLU 案例...

0

也许这个原因是它实际上使用了前向模式微分?🤔

7

哦等等,我刚刚意识到,如果你计算高阶 VJP,它也会呈指数增长?首先,将图形加倍,然后取其 VJP,再次将图形加倍,达到 4 倍,然后是 8 倍、16 倍等......

2

我预计 VJP 会比当前方法稍微慢一些,因为现在我只计算一个输入的导数,而使用 VJP,我将计算所有输入的导数,然后乘以一个单热(有效地切片)它)。输出的大小应该保持不变,但我不知道它将如何解决原来的问题。

9

我不知道伴随方法——它缓存的中间结果是什么?在二阶导数中,它如何避免对一阶导数产生的乘积中的每一项求导?

顺便说一句,我只是想提一下,使用 ReLU 的网络的二阶导数(以及所有高阶导数)为 0。

2

正确的。不管怎样,无论使用 VJP 还是前向模式,我怀疑原因实际上是因为你将高阶梯度计算为前一级别梯度的总和,所以你的图会变得更大。

无论您运行正向模式(由于链式法则)还是反向模式(由于在每个连续图上运行反向传播时将图加倍),都是如此

7
图像

这是使用伴随方法的前向 + 后向图的示例。

8

总和仅在批量维度上,就像计算损失时所做的那样。不同的批次维度永远不会相互作用,因此它不是分支因子。我可以通过运行批量大小 1 来确认这一点,我们看到了同样的问题:

0.010080099105834961
0.0022132396697998047
0.003983259201049805
0.007370948791503906
0.014194965362548828
0.028818845748901367
0.05858922004699707
0.13225102424621582
0.3321499824523926
0.7825288772583008
7

嘿!

我推荐作为一个很好的工具来可视化你的图有多大https://github.com/szagoruyko/pytorchviz 你可以添加torchviz.make_dot(dyi.sum()).view()到你的 for 循环中以查看 autograd 图在每次迭代中如何演变。每个 xi 都与顶部的 select_backward() 节点之一相关联。

但在这种情况下,问题确实在于双向后依赖于向后,因此您不断向图中添加越来越多的内容(使用 relu 时不会这样做,因为没有依赖性)。

6

在反向传播中,我们将给定模块中的导数乘以上游导数。因此,在第一遍结束时,我们将得到每个模块的导数的乘积。在下一次迭代中(对于高阶导数),我们是否需要对所有这些相乘项求导,并根据乘积规则创建更多项?

@Acciorocketships 是对的,每次应用链式法则+乘积法则时,图表基本上都会翻倍。

听起来也确实像在自定义函数中包装硬编码的向后规则之类的技术会在一定程度上减少分支因子,尽管帮助的程度将取决于特定的函数,并且非常高阶的导数仍然会出现问题。

不会。正如您从我上面写下的方程式中看到的,反向传播使用称为伴随方法的方法来缓存中间结果。这意味着只有激活函数才能在简单的 MLP 中创建更多算子。

是的,有一些缓存,因为我们保存向后的变量,但根据操作,在进行高阶导数时可能会进行更多的重新计算。例如,在a(a(x))is的三阶导数中a''(a(x))a'(x)a'(x) + a'(a(x))a''(x)),您不必要a'(x)再次计算。一个极端的例子是,如果您正在进行 n 阶导数,例如sin(sin(sin(不断重新计算 6 个量 ( cos(sin(sin, cos(sin, cos, sin(sin(sin, sin(sin, sin)。

为了尝试进行一些实际测量,下表是在sin(sin(sin有和没有缓存的情况下在 cpu 上对 (2048, 1024) 张量进行七阶导数得出的。张量很大,因为我们需要进行大量计算才能从缓存中受益。

健康)状况 时间(秒) 喜欢的人 扩张 因斯 穆尔 否定 添加
无缓存 6.186 149 7 7 7 155 3270 第478章 2199
带缓存 4.442 149 7 7 7 155 3270 第478章 2199
预调度缓存 3.365 21 7 7 7 20 2207 189 1726
单击展开并查看代码 ```python from torch.utils._python_dispatch import TorchDispatchMode from torch.overrides import TorchFunctionMode from torch.utils.weak import WeakTensorKeyDictionary from torch._dispatch.py​​thon import enable_python_dispatcher, enable_pre_dispatch from torch._C import DispatchKey import time torch.random.manual_seed( 0) ops_count = {} class SinCosCache(TorchDispatchMode): def __init__(self, pre_dispatch=False): if pre_dispatch: super().__init__(DispatchKey.PreDispatch) self.cached_ops = ( torch.ops.aten.sin.default, torch.ops.aten.cos.default, ) self.caches = {f: WeakTensorKeyDictionary() for f in self.cached_ops} self.cache_hits = {f: 0 for f in self.cached_ops} self.cache_misses = {f: 0 for f in self.cached_ops} def __torch_dispatch__(self, func, types, args=(), kwargs=None): ops_count[func] = ops_count.get(func, 0) + 1 kwargs = {} 如果 kwargs 为 None else kwargs 如果 func 在 self.caches 中:cache = self.caches[func] 如果 args[0] 不在缓存中:cache[args[0]] = func(*args, **kwargs) self.cache_misses[func] + = 1 else: self.cache_hits[func] += 1 return cache[args[0]] return func(*args, **kwargs) # 一个相当大的输入,并且在CPU上!a = torch.rand(2048, 1024, dtype=torch.float64, require_grad=True) def fn(a): return a.sin() def run(): out = fn(fn(fn(a))) for i in range(7): print(i) out = torch.autograd.grad(out.sum(), a, create_graph=True)[0] def do_time(): start = time.time() run() end = time.time() 返回结束 - 开始 def print_ops_count(): 对于 op,在 ops_count.items() 中计数: print(f"{op.__name__}: ", count) 与 enable_python_dispatcher()、enable_pre_dispatch()、SinCosCache () 作为缓存: print("缓存:",do_time()) print_ops_count() ops_count = {} print("无缓存:",do_time()) 与 enable_python_dispatcher()、enable_pre_dispatch()、SinCosCache(pre_dispatch=True)作为缓存: print("pre_dispatch 缓存:", do_time()) print_ops_count() ```

总而言之,我认为没有太多可以做的事情。硬编码高阶规则以减轻分支因子似乎是特定于用例的,并且今天的自定义函数似乎是一个足够好的扩展点来实现这一点。

一般来说,某种缓存助手可能是高阶反向传播的有用工具,但与选择性检查点一样,由于版本计数器问题,如今在 autograd 下进行缓存很容易出错。它似乎也没有足够的影响力来影响非常高阶的导数,主要是因为运行时仍然由 mul 主导,并且 mul 不能轻易缓存,因为它非常依赖于顺序,因为您必须以完全相同的顺序乘以项才能得到缓存对你的 muls 的命中。沿着这些思路思考,随机的想法是有一个机制(可能是一个子类),在向后过程中惰性地执行 muls,直到达到非线性,此时 mul 将被具体化,这样也许会有更多的机会命中缓存。

3

将这个问题保留为“需要研究”,因为存在一些潜在的不必要的高阶导数计算,并且似乎有一些通用解决方案可以通过缓存等帮助提高性能。但是,到目前为止,这些改进仅在非常合成案例。拥有更多现实生活中的用例将有助于阐明其核心用途。