# Clogs and Leaks: Why My Tensors Won't Flow

Defining a class of tricky bugs in PyTorch programs

#### Monday, November 23, 2020 · 5 min read

Note: Jekyll and Hyde is how this blog goes. There are weeks of rational thought and weeks of irrational ramblings. This fall has been much of the latter, for good reason, but here is a break(!) in the clouds. Okay, okay. The truth is that I wrote this essay for my CS class this quarter (leave it to Pat to assign an essay for a CS class!). But then, I think I have reached the age where every written assignment in college is to be treated as an opportunity to say something I otherwise would not have a chance to say… and so the essay, lightly edited, finds its way to this blog.

In this essay I want to describe two kinds of tricky bugs that might creep into your PyTorch programs. I call the bugs “clogs” and “leaks.” In my mind, clogs and leaks reveal an exciting possible research direction for anyone interested in designing better APIs for automatic differentiation in machine learning.

*Note: The examples presented, though in principle timeless, were tested using
Python 3.8.5 running PyTorch 1.6.0.*

## Preliminaries: PyTorch’s Pipes

If you are familiar with PyTorch internals, you can skip this section. If not,
a brief review of a wonderful topic: how does PyTorch differentiate your code
for gradient descent? The technical term for PyTorch’s approach is *tape-based
reverse-mode automatic differentiation*. As you perform arithmetic computations
on your variables, PyTorch tracks the intermediate values in a computation
graph. When you want to differentiate a value, you call `.backward()` on that
value. PyTorch then walks *backwards* along this computation graph, computing
the derivative at each step and accumulating them according to the chain rule.
Eventually, the leaf nodes of the graph contain the derivatives you asked for.

Let me give a small example. Suppose we wanted to compute $\left.\frac{d(2x^2)}{dx}\right|_{x=3}$. We might write a program that looks like this:

```
import torch

x = torch.tensor(3., requires_grad=True)
y = 2 * x**2
y.backward()   # walk the graph backwards from y
print(x.grad)  # tensor(12.)
```

This program generates the following computation graph.
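
To peek at that graph from Python, you can follow a tensor's `grad_fn` and its `next_functions` back to the leaves. The sketch below relies only on those two attributes; `graph_nodes` is a hypothetical helper, and the exact node class names (such as `MulBackward0`) may differ between PyTorch versions.

```python
import torch

def graph_nodes(fn):
    """Depth-first walk over an autograd graph, returning node class names."""
    names = []
    stack = [fn]
    while stack:
        node = stack.pop()
        if node is None:  # an input that does not require grad (e.g. the constant 2)
            continue
        names.append(type(node).__name__)
        stack.extend(next_fn for next_fn, _ in node.next_functions)
    return names

x = torch.tensor(3., requires_grad=True)
y = 2 * x**2
names = graph_nodes(y.grad_fn)
print(names)  # a multiply node, a power node, then AccumulateGrad for the leaf x
```

The `AccumulateGrad` node at the end is where the final derivative is deposited into `x.grad`.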

When you call `.backward()`, PyTorch’s automatic differentiation engine walks
backwards along the graph, computing derivatives at *each* step. By the chain
rule, the product of these gives the overall derivative we sought.
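
Spelled out for this example, with the intermediate value $u = x^2$ (so that $y = 2u$), the two local derivatives along the graph multiply to give the `12` the program prints:

$$
\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} = 2 \cdot 2x = 4x,
\qquad
\left.\frac{dy}{dx}\right|_{x=3} = 12
$$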

## Clogs

Now, consider this simple PyTorch program to compute $\left.\frac{d(\sqrt{x} + x)}{dx}\right|_{x=4}$. What do you expect to be printed?

```
x = torch.tensor(4., requires_grad=True)
y = sqrt(x) + x
y.backward()
print(x.grad)
```

A casual user or AP calculus student would *expect* to see 1.25 printed, of
course, since $\frac{1}{2\sqrt{4}} + 1 = 1.25$. But what *actually* gets printed is 1. Why?

Ah! I didn’t show you the full program: I hid the imports. It turns out that
the first line of this program is `from math import sqrt`, *not*
`from torch import sqrt`. Now, the Python standard library’s `math.sqrt` is not
a PyTorch-differentiable function, and so PyTorch is unable to track the flow
of derivatives through `sqrt(x)`.
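
To see the clog concretely, here is a sketch that runs both versions side by side on the same leaf `x` (the names `clogged_grad` and `flowing_grad` are just for illustration). Note the reset of `x.grad` in between, since PyTorch accumulates gradients into `.grad` across `backward()` calls.

```python
import math
import torch

x = torch.tensor(4., requires_grad=True)

# math.sqrt converts x to a plain Python float, so autograd never sees it.
y_clogged = math.sqrt(x) + x
y_clogged.backward()
clogged_grad = x.grad.item()

x.grad = None  # clear the accumulated gradient before the second run

# torch.sqrt stays on the tape, so both terms contribute to the gradient.
y_flowing = torch.sqrt(x) + x
y_flowing.backward()
flowing_grad = x.grad.item()

print(clogged_grad, flowing_grad)  # 1.0 1.25
```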

As a result of this bug, backpropagation gets “stuck” on the way back, and only
the derivative of the `+ x` term, i.e. 1, is deposited. This is a clog: the
gradients can’t flow! In the computation graph below, the dotted arrow
represents the clog.

The reason calling `math.sqrt()` on a PyTorch tensor is not a runtime error is
that single-element PyTorch tensors implicitly convert to “raw” floating-point
numbers as needed. Most of the time this is a useful and indispensable feature.
But I believe this situation should *at the very least* raise an error or a
warning. While the example I presented was reasonably straightforward, there
are *many* different ways to “clog” backpropagation, with varying degrees of
insidiousness (for example, what happens when you mutate a variable in place?).
It can be a nightmare to debug such situations when something goes wrong (that
is, if you notice the bug in the first place!).
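
PyTorch won’t warn here, but `torch.autograd.gradcheck` offers a blunt, after-the-fact check: it compares autograd’s gradient against a finite-difference estimate, so a clog shows up as a mismatch. This is a minimal sketch, assuming the function is cheap enough to evaluate a few extra times on a double-precision input; it can only catch clogs on code paths the check actually exercises.

```python
import math
import torch
from torch.autograd import gradcheck

def clogged(x):
    return math.sqrt(x) + x  # math.sqrt silently drops off the tape

def flowing(x):
    return torch.sqrt(x) + x

# gradcheck expects double-precision inputs for accurate finite differences.
x = torch.tensor([4.0], dtype=torch.float64, requires_grad=True)

# With raise_exception=False, gradcheck returns False instead of raising.
clogged_ok = gradcheck(clogged, (x,), raise_exception=False)
flowing_ok = gradcheck(flowing, (x,), raise_exception=False)
print(clogged_ok, flowing_ok)  # False True
```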

*By the way:* the celebrated “reparametrization trick” that powers variational
autoencoders is really just a workaround for a gradient clog problem. To train
a variational autoencoder, you need to compute the derivative of a sample of a
probability distribution with respect to the distribution’s parameters (e.g.
the mean $\mu$ and variance $\sigma^2$ of a Gaussian distribution).
Unfortunately, naïvely sampling from a parametrized distribution abruptly
truncates the computation graph with respect to the parameters, because the
random number generator is not differentiable all the way through (who *knows*
what *it’s* doing!). The solution is to sample from a standard normal
distribution (where $\mu=0$ and $\sigma=1$), and then re-scale the sample
by multiplying by $\sigma$ and adding $\mu$. Of course, multiplication and
addition *are* easily differentiable, and so the gradients can now flow.
Problem solved!
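
In PyTorch terms, here is a minimal sketch, assuming `torch.distributions.Normal` as the parametrized distribution: `Normal.sample()` draws under `torch.no_grad()` and so clogs the graph, while the rescaled standard-normal draw lets gradients reach both parameters. Here `sigma` is the standard deviation, not the variance.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.tensor(1.0, requires_grad=True)
sigma = torch.tensor(2.0, requires_grad=True)

# Naive: Normal.sample() draws under torch.no_grad(), cutting the graph.
z_naive = Normal(mu, sigma).sample()
print(z_naive.requires_grad)  # False

# Reparametrized: draw eps ~ N(0, 1), then scale by sigma and shift by mu.
eps = torch.randn(())
z = mu + sigma * eps
z.backward()
print(mu.grad, sigma.grad)  # dz/dmu = 1, dz/dsigma = eps
```

`torch.distributions` also packages this rescaling as `Normal.rsample()`.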

## Leaks

Now, consider this slightly more complicated PyTorch program. We are going to implement a silly reinforcement learning algorithm. Here is the situation: There is a truck driving on the road with constant velocity, and your goal is to catch up to it and drive right alongside the truck. At each timestep you are allowed to choose your velocity, and then you’re told how far you are from the truck.

The setup:

```
import torch

truck_velocity = torch.tensor(3.142)
truck_position = torch.tensor(2.718)

def get_measurement(car_position):
    # The truck advances one step each time we take a measurement.
    global truck_position
    truck_position = truck_position + truck_velocity
    return torch.abs(truck_position - car_position)
```

And a simple online gradient-based learning algorithm:

```
my_velocity = torch.tensor(0.01)
my_position = torch.tensor(0.)
for i in range(500):
    my_velocity.requires_grad_()
    my_position = my_position + my_velocity
    loss = get_measurement(my_position)
    loss.backward()
    # Gradient step on the velocity, detached from its old history.
    my_velocity = my_velocity.detach() - my_velocity.grad * 0.01
```

Unlike last time, there’s nothing up my sleeve here: this is all reasonable PyTorch code, and it actually works just fine.

But, if you run it for long enough (say, 1000 iterations), you’ll notice something odd: each step starts taking longer and longer. The algorithm is accidentally quadratic! You can see this behavior quite clearly in this graph, which shows a linear growth in iteration time from step to step (the spikes are garbage collection pauses).

How can this be? Isn’t each loop doing the same calculation?

Here is one hypothesis: if you’ve read this paper, you might look to see if
we’re `.detach()`-ing `my_velocity`. The `.detach()` function snips off all
incoming edges to a node in the computation graph, essentially creating an
artificial clog. If we forget to do that, the gradients would “leak” back in
time across multiple steps in the graph, all the way back to the first step,
and each iteration would therefore take longer and longer, just as we’re
observing.
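
A small sketch of that artificial clog, with values chosen only for illustration: `z` holds the same number as `y` but has no history, so backpropagating through `w = z * x` sees only the direct dependence on `x`.

```python
import torch

x = torch.tensor(2., requires_grad=True)
y = 3 * x        # y is on the tape
z = y.detach()   # same value as y, but with no incoming edges
print(z.requires_grad, z.grad_fn)  # False None

w = z * x        # mathematically w = 3x^2, whose derivative is 12 at x = 2
w.backward()
print(x.grad)    # tensor(6.): only the direct path through x survives
```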

But, alas, this is not the source of the bug: as you can see, we *are*
detaching `my_velocity` when we update it. So, what’s really going on here?

It’s tricky! The leak is in `my_position`, which subtly depends on *all*
previous values of `my_velocity` and therefore makes backpropagation compute
gradients for *all* previous timesteps. The dataflow diagram below hopefully
clarifies this point. Notice how each `velocity` has its parent nodes detached
(thanks to the call to `.detach()`!), but `loss` still has an indirect
dependence on the chain of `positions`.
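
Wall-clock timings are noisy, so another way to confirm this diagnosis is to count the autograd nodes reachable from each step’s `loss`. The sketch below re-runs the loop above with `get_measurement` folded inline; `count_graph_nodes` and `leaky_graph_sizes` are hypothetical helpers, not PyTorch APIs. If the leak is real, the count should grow by the same amount every step.

```python
import torch

def count_graph_nodes(root):
    """Count the autograd nodes reachable from `root` via next_functions."""
    seen, stack = set(), [root]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        stack.extend(next_fn for next_fn, _ in node.next_functions)
    return len(seen)

def leaky_graph_sizes(steps):
    """Re-run the truck-chasing loop, recording the graph size behind each loss."""
    truck_velocity = torch.tensor(3.142)
    truck_position = torch.tensor(2.718)
    my_velocity = torch.tensor(0.01)
    my_position = torch.tensor(0.)
    sizes = []
    for _ in range(steps):
        truck_position = truck_position + truck_velocity
        my_velocity.requires_grad_()
        my_position = my_position + my_velocity  # still chained to older positions
        loss = torch.abs(truck_position - my_position)
        sizes.append(count_graph_nodes(loss.grad_fn))
        loss.backward()
        my_velocity = my_velocity.detach() - my_velocity.grad * 0.01
    return sizes

sizes = leaky_graph_sizes(5)
print(sizes)  # strictly increasing, by a constant step
```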

Finding the correct place to insert the line `my_position = my_position.detach()`
is left as a not-quite-trivial exercise to the reader.
Beware! Putting it in the *wrong* place will either have no effect *or* cause
`my_velocity` to always have gradient 0.

Just like memory leaks, gradient leaks can be extremely sneaky. They pop up
whenever your inference is “stateful” — think of applications like physics
controllers, reinforcement learning, animated graphics, RNNs, and so on. I
would not be surprised if many popular implementations of such algorithms *do*
have “gradient leak” bugs. However, the bugs usually only manifest themselves
visibly when the inference passes through enough timesteps for the leak to
compound. Just like a dripping tap, you might not notice your losses until you
get the bill at the end of the month… and then you need to track down the
source of the leak and figure out the right way to fix it.

## Plungers and patches? An appeal for PLumbing…

In the long term, how can we protect ourselves from this class of bugs? One
potential solution is to embed the API inside a language whose type system
tracks the creation of the computation graph. You might be able to use
well-understood techniques like *taint analysis* or *linear types* (pun not
intended), which traditionally track the flow of *information*, to track the
flow of *differentiability* through the program instead.

Let me be slightly more concrete about this suggestion. In our “clog” example,
a good type system might detect that `sqrt` cuts off the computation graph,
and, knowing that $y$ no longer depends on $x$ in the expected way, complain at
compile-time when we try to request $dy/dx$. In our “leak” example, a good type
system might notice that the “old” `my_position` effectively goes out of scope
when it is re-assigned, and therefore it might complain that an unreachable
reference to it actually does persist through the new `my_position`. Such
checks seem very reasonable to demand from a modern type system.
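
Short of a new language, a crude runtime stand-in for the clog check is a wrapper that refuses the implicit float conversion that `math.sqrt` relies on. This is a hypothetical sketch (`OnTape` is not a PyTorch class): unlike a real type system, it only fails when the bad path actually runs, and it does nothing about leaks.

```python
import math
import torch

class OnTape:
    """Hypothetical wrapper that refuses to be silently converted off the tape."""

    def __init__(self, tensor):
        self.tensor = tensor

    def __float__(self):
        # math.sqrt and friends call __float__; fail loudly instead of clogging.
        raise TypeError("value would fall off the autograd tape")

    def apply(self, fn):
        # Transformations are expected to be torch functions that keep the graph.
        return OnTape(fn(self.tensor))

x = torch.tensor(4., requires_grad=True)

caught = False
try:
    math.sqrt(OnTape(x))
except TypeError as err:
    caught = True
    print("caught:", err)

y = OnTape(x).apply(torch.sqrt)
y.tensor.backward()
print(x.grad)  # tensor(0.2500)
```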