torch.no_grad()函数的作用

torch.no_grad()函数是禁用梯度计算的上下文管理器。

当我们确信不会调用Tensor.backward()时,禁用梯度计算很有用,因为它将减少计算的内存消耗。

在这种模式下,即使输入的向量的requires_grad=True,每次计算的结果也将为requires_grad=False。

但是,要注意:所有工厂函数或创建新张量的函数,都不受此模式的影响。

另外,提醒一下:torch.no_grad()不仅是个函数,也是一个装饰器。如下代码所示。

torch.no_grad()函数应用场景

>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False


>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False


>>> @torch.no_grad
... def tripler(x):
...     return x * 3
>>> z = tripler(x)
>>> z.requires_grad
False


>>> # 工厂函数并不受no_grad的影响
>>> with torch.no_grad():
...     a = torch.nn.Parameter(torch.rand(10))
>>> a.requires_grad
True

补充说明:什么是工厂函数?

工厂函数是用于生成tensor的函数。

常见的工厂函数有torch.rand、torch.randint、torch.randn、torch.eye等,更多介绍请移步PyTorch官网介绍:

https://pytorch.org/cppdocs/notes/tensor_creation.html#factory-functions