torch (5): Locally disabling gradient computation

  • October 7, 2019
  • Notes

Copyright notice: This is an original post by the blogger, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.

Original link: https://blog.csdn.net/weixin_36670529/article/details/101199263

The context managers torch.no_grad(), torch.enable_grad(), and torch.set_grad_enabled() are helpful for locally disabling and enabling gradient computation. See Locally disabling gradient computation for more details on their usage. These context managers are thread-local, so they won't take effect if you send work to another thread using the threading module, etc.
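Beyond the with-statement form shown in the examples below, these helpers can also be applied as function decorators, which disables gradients for every call to the decorated function. A minimal sketch of this usage (the double helper is only an illustrative name):

import torch

# Applying torch.no_grad as a decorator disables gradient
# computation for the whole body of the function.
@torch.no_grad()
def double(x):
    return x * 2

x = torch.zeros(1, requires_grad=True)
y = double(x)
print(y.requires_grad)  # False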

Examples:

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False

>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False

>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True

>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
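To illustrate the thread-local caveat above, the sketch below (assuming the default PyTorch behavior that a newly spawned thread starts with gradients enabled) shows that a torch.no_grad() block in the main thread does not carry over to a worker thread:

import threading

import torch

x = torch.zeros(1, requires_grad=True)

def worker():
    # Grad mode is thread-local: the torch.no_grad() block in the
    # main thread does not apply here, so gradients are still enabled.
    y = x * 2
    print("worker: y.requires_grad =", y.requires_grad)  # True

with torch.no_grad():
    y = x * 2
    print("main:   y.requires_grad =", y.requires_grad)  # False
    t = threading.Thread(target=worker)
    t.start()
    t.join()

If a worker thread should also run without gradients, enter torch.no_grad() inside that thread's target function.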