This post is aimed at helping anyone who runs into a misleading RuntimeError when using PyTorch's torch.distributions functions. Specifically, it discusses one particular trap, a misleading error message, with torch.distributions.MultivariateNormal; the same issue may apply to other distributions as well.
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

You will run into this error if the parameter(s) of the distribution require gradients. Consider the following snippet:
import torch

mu = torch.tensor([0.0, 0.0], requires_grad=True)
sigma = torch.eye(2)
distrib = torch.distributions.MultivariateNormal(mu, sigma)
a = distrib.sample()
# loss = - distrib.log_prob(a)  # This will run fine because a is a Tensor
loss_np = - distrib.log_prob(a.numpy())
Running this snippet raises the RuntimeError quoted above, with the following traceback:
Traceback (most recent call last):
  File "distrib_debug.py", line 6, in <module>
    loss_np = - distrib.log_prob(a.numpy())
  File "~/python3.5/site-packages/torch/distributions/multivariate_normal.py", line 181, in log_prob
    diff = value - self.loc
  File "~/python3.5/site-packages/torch/tensor.py", line 376, in __array__
    return self.cpu().numpy()
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
Here, we are interested in learning the parameter mu, so we set requires_grad=True on it in order to have gradients computed. But when we call distrib.log_prob(a.numpy()), we get the above RuntimeError, which is misleading: as the traceback shows, the subtraction value - self.loc inside log_prob forces numpy to convert the mu Tensor (Variable) via __array__, and that conversion is what fails. The error message then suggests detaching the mu Tensor from the computation graph, converting it to a numpy array, and calling log_prob(...) on that – but this means the gradients for mu will not be computed, which is not what we want.
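To actually fix the problem, keep the sampled value as a Tensor so that gradients can flow back to mu. Here is a minimal sketch; the backward() call and the numpy-to-tensor conversion are my additions to illustrate a typical workflow, not part of the original snippet:

import torch

mu = torch.tensor([0.0, 0.0], requires_grad=True)
sigma = torch.eye(2)
distrib = torch.distributions.MultivariateNormal(mu, sigma)

a = distrib.sample()           # a is a Tensor, so no numpy conversion happens
loss = - distrib.log_prob(a)   # runs fine and stays on the computation graph
loss.backward()                # gradients are computed for mu
print(mu.grad)

# If the value comes from numpy, convert it back to a Tensor first:
import numpy as np
b = np.array([0.5, -0.5], dtype=np.float32)
loss_b = - distrib.log_prob(torch.from_numpy(b))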
The same call works fine if mu does not require gradients. For example, the following code runs without error:
import torch

mu = torch.tensor([0.0, 0.0])  # Note: requires_grad is False by default
sigma = torch.eye(2)
distrib = torch.distributions.MultivariateNormal(mu, sigma)
a = distrib.sample()
loss_np = - distrib.log_prob(a.numpy())
We will get the above misleading RuntimeError if even one of the parameters (for example, only sigma) requires gradients to be computed.
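To illustrate this, here is a sketch of my own (not from the original post) where only the covariance sigma requires gradients; passing a numpy array to log_prob fails in the same misleading way:

import torch

mu = torch.tensor([0.0, 0.0])            # requires_grad is False
sigma = torch.eye(2).requires_grad_()    # only the covariance requires gradients
distrib = torch.distributions.MultivariateNormal(mu, sigma)
a = distrib.sample()
loss_np = - distrib.log_prob(a.numpy())  # raises the same misleading RuntimeError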
You may notice from the comment in the first code snippet of this post that calling log_prob(a) directly does not cause this error, because a is a Tensor. But this is neither intuitive nor explained in the error message, even though we can use the validate_args parameter to explicitly request validation of the input arguments to the distribution functions. The useful validate_args argument was introduced in this PR and was merged in March 2018.
In summary, to avoid some headaches due to misleading RuntimeError messages, you can set the validate_args argument to True when initializing a distribution, like this:
distrib = torch.distributions.MultivariateNormal(mu, sigma, validate_args=True)
This performs several sanity checks on the supplied arguments, including boundary conditions, and raises useful/sensible ValueErrors before they can turn into puzzling RuntimeErrors. I think validation of the input arguments to the distribution functions is disabled by default because of its performance overhead. That is a good choice, but stating it clearly in the documentation would help users.
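For instance, with validation enabled, the numpy input from the first snippet is rejected up front with a clear ValueError instead of the confusing RuntimeError. A sketch of what this looks like (the exact message wording may differ across PyTorch versions):

import torch

mu = torch.tensor([0.0, 0.0], requires_grad=True)
sigma = torch.eye(2)
distrib = torch.distributions.MultivariateNormal(mu, sigma, validate_args=True)
a = distrib.sample()

try:
    loss_np = - distrib.log_prob(a.numpy())
except ValueError as e:
    print(e)  # e.g. "The value argument to log_prob must be a Tensor"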