在PyTorch中,tensor有一个requires_grad参数,如果设置为True,则反向传播时,该tensor就会自动求导。
tensor的requires_grad的属性默认为False,若一个节点(叶子变量:自己创建的tensor)requires_grad被设置为True,那么所有依赖它的节点requires_grad都为True(即使其他相依赖的tensor的requires_grad = False)。
当requires_grad设置为False时,反向传播时就不会自动求导,因此可以节约内存。
with torch.no_grad的作用。在该模块下,所有计算得出的tensor的requires_grad都自动设置为False。
即使一个tensor(命名为x)的requires_grad = True,在with torch.no_grad计算,由x得到的新tensor(命名为w标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。例子如下所示:
import torch x = torch.randn(5, 3, requires_grad = True) y = torch.randn(5, 3, requires_grad = True) z = torch.randn(5, 3, requires_grad = True) with torch.no_grad(): w = x + y + z print(w.requires_grad) print(w.grad_fn) print(w.requires_grad) print(z) print(y) print(z) print(w)
关于python中with的用法:
在with语句中的操作代码执行前,先执行__enter__中的代码;操作代码执行完后,再执行__exit__中的代码 __enter__=>with=>__exit__
class Sample: def __enter__(self): print("In__enter__()") return self def __exit__(self, type,value, trace): print("In__exit__()") def doSomething(self): a = 1/2 return a with Sample() as sample: print(sample.doSomething())
In__enter__() 0.5 In__exit__()