Clogs and Leaks: Why My Tensors Won't Flow

Defining a class of tricky bugs in PyTorch programs

Monday, November 23, 2020 · 5 min read

Note: Jekyll and Hyde is how this blog goes. There are weeks of rational thought and weeks of irrational ramblings. This fall has been much of the latter, for good reason, but here is a break(!) in the clouds.

Okay, okay. The truth is that I wrote this essay for my CS class this quarter (leave it to Pat to assign an essay for a CS class!). But then, I think I have reached the age where every written assignment in college is to be treated as an opportunity to say something I otherwise would not have a chance to say… and so the essay, lightly edited, finds its way to this blog.

In this essay I want to describe two kinds of tricky bugs that might creep into your PyTorch programs. I call the bugs “clogs” and “leaks.” In my mind “clogs” and “leaks” reveal an exciting possible research direction for anyone interested in designing better APIs for automatic differentiation in machine learning.

Note: The examples presented, though in principle timeless, were tested using Python 3.8.5 running PyTorch 1.6.0.

Preliminaries: PyTorch’s Pipes

If you are familiar with PyTorch internals, you can skip this section. If not, a brief review of a wonderful topic: how does PyTorch differentiate your code for gradient descent? The technical term for PyTorch’s approach is tape-based reverse-mode automatic differentiation. As you perform arithmetic computations on your variables, PyTorch tracks the intermediate values in a computation graph. When you want to differentiate a value, you call .backward() on that value. PyTorch then walks backwards along this computation graph, computing the derivative at each step and accumulating them according to the chain rule. Eventually, the leaf nodes of the graph contain the derivatives you asked for.

Let me give a small example. Suppose we wanted to compute $d(2x^2)/dx|_{x=3}$. We might write a program that looks like this:

x = torch.tensor(3., requires_grad=True)
y = 2 * x**2

This program generates the following computation graph.


When you call .backward(), PyTorch's automatic differentiation engine walks backwards along the graph, computing the derivative at each step. By the chain rule, the product of these gives the overall derivative we sought.
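To make the walk concrete, here is a minimal sketch that builds the same graph, peeks at the nodes PyTorch recorded, and then runs backpropagation. The grad_fn and next_functions attributes are PyTorch's own; everything else is just for illustration.

```python
import torch

x = torch.tensor(3., requires_grad=True)
y = 2 * x**2

# Each non-leaf tensor remembers the operation that produced it.
print(y.grad_fn)                 # the node for the final multiplication
print(y.grad_fn.next_functions)  # edges leading back towards x

# Walk the graph backwards, depositing dy/dx at the leaf x.
y.backward()
print(x.grad)  # tensor(12.), since d(2x^2)/dx = 4x = 12 at x = 3
```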



Now, consider this simple PyTorch program to compute $d(\sqrt{x} + x)/dx|_{x=4}$. What do you expect to be printed?

x = torch.tensor(4., requires_grad=True)
y = sqrt(x) + x
y.backward()
print(x.grad)

A casual user or AP calculus student would expect to see 1.25 printed, of course. But what actually gets printed is 1. Why?

Ah! I didn’t show you the full program: I hid the imports. It turns out that the first line of this program is from math import sqrt, not from torch import sqrt. Now, the Python standard library’s math.sqrt is not a PyTorch-differentiable function, and so PyTorch is unable to track the flow of derivatives through sqrt(x).

As a result of this bug, backpropagation gets “stuck” on the way back, and only the derivative of x, i.e. 1, is deposited. This is a clog — the gradients can’t flow! In the computation graph below, the dotted arrow represents the clog.

[Figure: clog graph]

The reason calling math.sqrt() on a PyTorch tensor is not a runtime error is that PyTorch tensors implicitly convert to “raw” floating-point numbers as needed. Most of the time this is a useful and indispensable feature. But I believe this situation should at the very least raise an error or a warning. While the example I presented was reasonably straightforward, there are many different ways to “clog” backpropagation, with varying degrees of insidiousness (for example, what happens when you mutate a variable in place?). It can be a nightmare to debug such situations when something goes wrong — that is, if you notice the bug in the first place!
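Here is a small side-by-side sketch of the clogged and unclogged versions. It assumes nothing beyond the standard math module and torch; the variable names are mine.

```python
import math
import torch

x = torch.tensor(4., requires_grad=True)

# math.sqrt silently converts the tensor to a plain Python float,
# so the result carries no link back to x.
clogged_root = math.sqrt(x)
print(type(clogged_root))  # <class 'float'>

y_clogged = clogged_root + x
y_clogged.backward()
grad_clogged = x.grad.item()  # 1.0: only the "+ x" term contributes

# torch.sqrt returns a tensor with a grad_fn, so gradients can flow.
x.grad = None  # reset the accumulated gradient before the second run
flowing_root = torch.sqrt(x)
print(flowing_root.grad_fn is not None)  # True

y_flowing = flowing_root + x
y_flowing.backward()
grad_flowing = x.grad.item()  # 1.25 = 1 / (2 * sqrt(4)) + 1

print(grad_clogged, grad_flowing)
```

Checking whether an intermediate value is still a tensor with a grad_fn is a cheap first test when a gradient looks suspiciously small.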

By the way: the celebrated “reparametrization trick” that powers variational autoencoders is really just a workaround for a gradient clog problem. To train a variational autoencoder, you need to compute the derivative of a sample of a probability distribution with respect to the distribution’s parameters (e.g. the mean $\mu$ and variance $\sigma^2$ of a Gaussian distribution). Unfortunately, naïvely sampling from a parametrized distribution abruptly truncates the computation graph with respect to the parameters, because the random number generator is not differentiable all the way through (who knows what it’s doing!). The solution is to sample from a standard normal distribution (where $\mu=0$ and $\sigma=1$), and then re-scale the sample by multiplying by $\sigma$ and adding $\mu$. Of course, multiplication and addition are easily differentiable, and so the gradients can now flow. Problem solved!
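A minimal sketch of the trick for a single scalar Gaussian, assuming arbitrary values for $\mu$ and $\sigma$ and an arbitrary seed. The naive draw uses torch.distributions.Normal(...).sample(), which returns a value cut off from the parameters; the reparametrized draw is written out by hand.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only to make the sketch repeatable

mu = torch.tensor(1.0, requires_grad=True)
sigma = torch.tensor(2.0, requires_grad=True)

# Naive sampling: the draw is cut off from mu and sigma (a clog).
naive_sample = torch.distributions.Normal(mu, sigma).sample()
print(naive_sample.requires_grad)  # False

# Reparametrized sampling: draw parameter-free noise, then scale and shift.
eps = torch.randn(())      # a sample from the standard normal
sample = mu + sigma * eps  # only differentiable operations touch mu, sigma
sample.backward()

print(mu.grad)     # d(sample)/d(mu) = 1
print(sigma.grad)  # d(sample)/d(sigma) = eps
```

PyTorch's distributions also offer an rsample() method that performs the reparametrized draw for you.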


Now, consider this slightly more complicated PyTorch program. We are going to implement a silly reinforcement learning algorithm. Here is the situation: There is a truck driving on the road with constant velocity, and your goal is to catch up to it and drive right alongside the truck. At each timestep you are allowed to choose your velocity, and then you’re told how far you are from the truck.

The setup:

truck_velocity = torch.tensor(3.142)
truck_position = torch.tensor(2.718)

def get_measurement(car_position):
    global truck_position
    truck_position = truck_position + truck_velocity
    return torch.abs(truck_position - car_position)

And a simple online gradient-based learning algorithm:

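my_velocity = torch.tensor(0.01, requires_grad=True)
my_position = torch.tensor(0.)
for i in range(500):
    my_position = my_position + my_velocity
    loss = get_measurement(my_position)
    loss.backward()  # deposit d(loss)/d(my_velocity) in my_velocity.grad
    # gradient step; .detach() turns the new velocity into a fresh leaf
    my_velocity =\
        my_velocity.detach() - my_velocity.grad * 0.01
    my_velocity.requires_grad_()  # track gradients again for the next step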

Unlike last time, there’s nothing up my sleeve here — this is all reasonable PyTorch code. This code actually works just fine.

But, if you run it for long enough (say, 1000 iterations), you’ll notice something odd: each step starts taking longer and longer. The algorithm is accidentally quadratic! You can see this behavior quite clearly in this graph, which shows a linear growth in iteration time from step to step (the spikes are garbage collection pauses).

[Figure: time graph, essentially linear]

How can this be? Isn’t each loop doing the same calculation?

Here is one hypothesis: if you’ve read this paper you might look to see if we’re .detach()-ing my_velocity. The .detach() function snips off all incoming edges to a node in the computation graph, essentially creating an artificial clog. If we forgot to do that, the gradients would “leak” back in time across multiple steps in the graph, all the way back to the first step, and each iteration would therefore take longer and longer, just as we’re observing.

But, alas, this is not the source of the bug: as you can see, we are detaching my_velocity when we update it. So, what’s really going on here?

It’s tricky! The leak is in my_position, which subtly depends on all previous values of my_velocity and therefore makes backpropagation compute gradients for all previous timesteps. The dataflow diagram below hopefully clarifies this point. Notice how each velocity has its parent nodes detached (thanks to the call to .detach()!), but loss still has an indirect dependence on the chain of positions.

[Figure: leak graph]
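One way to see the leak without a stopwatch is to count how many autograd nodes are reachable from loss at each step. The sketch below is a variant of the loop above with the truck update inlined; count_graph_nodes is a hypothetical helper of mine, built on PyTorch's grad_fn and next_functions attributes. It only diagnoses the leak and does not fix it.

```python
import torch

def count_graph_nodes(tensor):
    """Count the autograd nodes reachable from tensor (hypothetical helper)."""
    seen = set()
    stack = [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        # next_functions holds (node, index) pairs for this node's inputs
        stack.extend(fn for fn, _ in node.next_functions)
    return len(seen)

truck_velocity = torch.tensor(3.142)
truck_position = torch.tensor(2.718)
my_velocity = torch.tensor(0.01, requires_grad=True)
my_position = torch.tensor(0.)

graph_sizes = []
for i in range(5):
    truck_position = truck_position + truck_velocity
    my_position = my_position + my_velocity
    loss = torch.abs(truck_position - my_position)
    graph_sizes.append(count_graph_nodes(loss))
    loss.backward()
    # same update as before: detach, step, and track gradients again
    my_velocity = (my_velocity.detach() - my_velocity.grad * 0.01).requires_grad_()

print(graph_sizes)  # keeps growing, even though my_velocity is detached
```

If the graph were being snipped correctly, the count would stay constant from step to step.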

Finding the correct place to insert the line my_position = my_position.detach() is left as a not-quite-trivial exercise to the reader. Beware! Putting it in the wrong place will either have no effect or cause my_velocity to always have gradient 0.

Just like memory leaks, gradient leaks can be extremely sneaky. They pop up whenever your inference is “stateful” — think of applications like physics controllers, reinforcement learning, animated graphics, RNNs, and so on. I would not be surprised if many popular implementations of such algorithms do have “gradient leak” bugs. However, the bugs usually only manifest themselves visibly when the inference passes through enough timesteps for the leak to compound. Just like a dripping tap, you might not notice your losses until you get the bill at the end of the month… and then, you need to figure out how to track down the source of the leak and figure out the right way to fix it.

Plungers and patches? An appeal for PLumbing…

In the long term, how can we protect ourselves from this class of bugs? One potential solution is to embed the API inside a language whose type system tracks the creation of the computation graph. You might be able to use well-understood techniques like taint analysis or linear types (pun not intended), which traditionally track the flow of information, to now track the flow of differentiability through the program.

Let me be slightly more concrete about this suggestion. In our “clog” example, a good type system might detect that sqrt cuts off the computation graph, and, knowing that $y$ does not directly depend on $x$ in the expected way anymore, complain at compile-time when we try to request $dy/dx$. In our “leak” example, a good type system might notice that the “old” my_position effectively goes out of scope when it is re-assigned, and therefore it might complain that an unreachable reference to it actually does persist through the new my_position. Such checks seem very reasonable to demand from a modern type system.
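Python cannot express such a check at compile time, but a rough runtime stand-in gives a feel for the “clog” half of the idea. The wrapper below is entirely hypothetical (require_differentiable is not a PyTorch API): it complains whenever a function receives a gradient-tracking tensor and returns something that no longer tracks gradients.

```python
import math
import torch

def require_differentiable(fn):
    """Hypothetical wrapper: fail loudly if fn drops its inputs' graph."""
    def wrapped(*args):
        result = fn(*args)
        tracked_input = any(
            isinstance(a, torch.Tensor) and a.requires_grad for a in args
        )
        tracked_output = (
            isinstance(result, torch.Tensor) and result.requires_grad
        )
        if tracked_input and not tracked_output:
            name = getattr(fn, "__name__", repr(fn))
            raise TypeError(f"{name} clogs the computation graph")
        return result
    return wrapped

x = torch.tensor(4., requires_grad=True)

checked_torch_sqrt = require_differentiable(torch.sqrt)
print(checked_torch_sqrt(x))  # fine: the result still tracks x

checked_math_sqrt = require_differentiable(math.sqrt)
try:
    checked_math_sqrt(x)
    clog_detected = False
except TypeError as err:
    clog_detected = True
    print(err)
```

A real type system would reject the math.sqrt call before the program ever ran; this sketch only catches it when the offending line executes, and only for functions you remember to wrap.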

◊ ◊ ◊