torch(五)、Locally disabling gradient computation
- 2019 年 10 月 7 日
- 筆記
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_36670529/article/details/101199263
The context managers torch.no_grad()
, torch.enable_grad()
, and torch.set_grad_enabled()
are helpful for locally disabling and enabling gradient computation. See Locally disabling gradient computation for more details on their usage. These context managers are thread local, so they won’t work if you send work to another thread using the :module:`threading` module, etc.
Examples:
>>> x = torch.zeros(1, requires_grad=True) >>> with torch.no_grad(): ... y = x * 2 >>> y.requires_grad False >>> is_train = False >>> with torch.set_grad_enabled(is_train): ... y = x * 2 >>> y.requires_grad False >>> torch.set_grad_enabled(True) # this can also be used as a function >>> y = x * 2 >>> y.requires_grad True >>> torch.set_grad_enabled(False) >>> y = x * 2 >>> y.requires_grad False