PyTorch with torch.no_grad 用法

PyTorch with torch.no_grad 用法

在PyTorch中,tensor有一个requires_grad参数,如果设置为True,则反向传播时,该tensor就会自动求导。
tensor的requires_grad的属性默认为False,若一个节点(叶子变量:自己创建的tensor)requires_grad被设置为True,那么所有依赖它的节点requires_grad都为True(即使其他相依赖的tensor的requires_grad = False)。

当requires_grad设置为False时,反向传播时就不会自动求导,因此可以节约内存。

with torch.no_grad的作用。在该模块下,所有计算得出的tensor的requires_grad都自动设置为False。

即使一个tensor(命名为x)的requires_grad = True,在with torch.no_grad计算,由x得到的新tensor(命名为w标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。例子如下所示:

import torch

x = torch.randn(5, 3, requires_grad = True)
y = torch.randn(5, 3, requires_grad = True)
z = torch.randn(5, 3, requires_grad = True)

with torch.no_grad():
    w = x + y + z
    print(w.requires_grad)
    print(w.grad_fn)

print(w.requires_grad)

print(z)
print(y)
print(z)
print(w)

关于python中with的用法:

在with语句中的操作代码执行前,先执行__enter__中的代码;操作代码执行完后,再执行__exit__中的代码
__enter__=>with=>__exit__
class Sample:
    def __enter__(self):
        print("In__enter__()")
        return self
    
    def __exit__(self, type,value, trace):
        print("In__exit__()")
    
    def doSomething(self):
        a = 1/2
        return a

with Sample() as sample:
    print(sample.doSomething())
In__enter__()
0.5
In__exit__()

发表回复

您的电子邮箱地址不会被公开。