torch(五)、Locally disabling gradient computation
- 2019 年 10 月 7 日
- 筆記
版權聲明:本文為部落客原創文章,遵循 CC 4.0 BY-SA 版權協議,轉載請附上原文出處鏈接和本聲明。
本文鏈接:https://blog.csdn.net/weixin_36670529/article/details/101199263
The context managers torch.no_grad()
, torch.enable_grad()
, and torch.set_grad_enabled()
are helpful for locally disabling and enabling gradient computation. See Locally disabling gradient computation for more details on their usage. These context managers are thread local, so they won』t work if you send work to another thread using the :module:`threading` module, etc.
Examples:
>>> x = torch.zeros(1, requires_grad=True) >>> with torch.no_grad(): ... y = x * 2 >>> y.requires_grad False >>> is_train = False >>> with torch.set_grad_enabled(is_train): ... y = x * 2 >>> y.requires_grad False >>> torch.set_grad_enabled(True) # this can also be used as a function >>> y = x * 2 >>> y.requires_grad True >>> torch.set_grad_enabled(False) >>> y = x * 2 >>> y.requires_grad False