🐛 描述错误
在我的应用程序中,我需要获取函数的 n 阶混合导数。然而,我发现 torch.autograd.grad 计算时间随着 n 的增加呈指数增加。这是预期的吗?有什么办法可以解决吗?
这是我区分函数 self.F 的代码(从 R^n -> R^1):
def differentiate(self, x):
x.requires_grad_(True)
xi = [x[...,i] for i in range(x.shape[-1])]
dyi = self.F(torch.stack(xi, dim=-1))
for i in range(self.dim):
start_time = time.time()
dyi = torch.autograd.grad(dyi.sum(), xi[i], retain_graph=True, create_graph=True)[0]
grad_time = time.time() - start_time
print(grad_time)
return dyi
这些是上述循环每次迭代打印的时间:
0.0037012100219726562
0.005133152008056641
0.008165121078491211
0.019922733306884766
0.059255123138427734
0.1910409927368164
0.6340939998626709
2.1612229347229004
11.042078971862793
我认为这是因为计算图的大小正在增加?有没有办法解决?我认为我可以通过使用 torch.func.grad 采取函数方法(大概不需要计算图)来规避这个问题。然而,这实际上增加了相同代码的运行时间!我没有正确理解 torch.func.grad 吗?
版本
torch 2.1.0