torch (5): Locally disabling gradient computation

  • October 7, 2019
  • Notes


The context managers torch.no_grad(), torch.enable_grad(), and torch.set_grad_enabled() are helpful for locally disabling and enabling gradient computation. See the PyTorch autograd documentation on locally disabling gradient computation for more details on their usage. These context managers are thread local, so the grad mode they set does not carry over if you send work to another thread using the threading module, etc.

Examples:

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False

>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False

>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True

>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
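
One case the examples above do not show is torch.enable_grad(), which re-enables gradient computation inside an outer no_grad() block. A minimal sketch of that nesting behavior:

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     with torch.enable_grad():
...         y = x * 2  # gradients are tracked again here
>>> y.requires_grad
True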
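
The thread-local behavior noted above can also be demonstrated directly. The following is a minimal sketch, assuming a PyTorch version where grad mode is thread local as described; worker is just an illustrative name:

>>> import threading
>>> x = torch.ones(1, requires_grad=True)
>>> def worker():
...     # The new thread starts with the default grad mode (enabled),
...     # because torch.no_grad() in the main thread is thread local.
...     print((x * 2).requires_grad)
>>> with torch.no_grad():
...     t = threading.Thread(target=worker)
...     t.start()
...     t.join()
True

If the worker should also run without gradient tracking, it must enter torch.no_grad() itself rather than rely on the mode set in the spawning thread.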