Understanding TensorFlow Source Code: RNN Cells

--

You can think of this as a guided reading of the first half of the code in rnn_cell_impl.py, where some RNN, LSTM, and GRU cells are defined. I’m keeping the code excerpts to a minimum in the hope that you’ll follow along with the actual source, and come away finding the TensorFlow codebase more approachable.

Prep: advanced Python concepts

__call__

__call__ is a special method on Python objects: defining it makes instances of the class callable. Callable objects, like functions or RNNCells, can be invoked with parentheses.

def my_callable(x):
    return x + 1

my_callable(5)

has the same result as this:

class MyClass:
    def __call__(self, x):
        return x + 1

my_callable = MyClass()
my_callable(5)
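That second pattern is exactly how RNN cells get used: you construct a cell object once, then invoke it like a function at each timestep. Here’s a rough sketch, assuming the TF 1.x-era API that rnn_cell_impl.py belongs to (the shapes are made up for illustration):

import tensorflow as tf

cell = tf.nn.rnn_cell.BasicRNNCell(num_units=64)
inputs = tf.placeholder(tf.float32, [32, 128])  # batch of 32, feature size 128
state = cell.zero_state(batch_size=32, dtype=tf.float32)

# The parentheses invoke the cell's __call__, which dispatches to its call method.
output, next_state = cell(inputs, state)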

@property

property is a built-in that lets you define computed attributes on objects. For example:

class MyClass:
    def __init__(self):
        self.x = 5

my_object = MyClass()
my_object.x

is pretty much the same thing as

class MyClass:
    def __init__(self):
        pass

    @property
    def x(self):
        return 5

my_object = MyClass()
my_object.x

This lets you compute a value on every access. For example, I could have a property that fetches data from a database whenever it’s accessed. It’s like a “getter” in Java or C#, but more Pythonic.
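A minimal sketch of that idea (User and fetch_from_db are made up for illustration):

def fetch_from_db(query, *params):
    # Stand-in for a real database call.
    print("running:", query, params)
    return []

class User:
    def __init__(self, user_id):
        self.user_id = user_id

    @property
    def orders(self):
        # Recomputed on every access: looks like an attribute, acts like a getter.
        return fetch_from_db("SELECT * FROM orders WHERE user_id = ?", self.user_id)

user = User(42)
user.orders  # triggers the "query"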

What’s the goal of rnn_cell_impl.py?

The goal of rnn_cell_impl.py is to provide some popular RNN cells and an easy way for people to create their own cells.

An RNN cell is basically a function that takes in an input and a state, returning a tuple of the output and the next state. In programming, this is very similar to a reducer function. You can chain a bunch of these cells together, sort of like in this pseudocode:

cell = LSTM(num_units)
inputs = [...a list of data...]
state = initial_state
result = []
for input_ in inputs:
    output, state = cell(input_, state)
    result.append(output)
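And since each step just folds (input, state) into the next state, the same loop can be written with functools.reduce. Here’s a self-contained sketch with a toy cell (the cell itself is made up; any function from (input, state) to (output, next_state) would do):

from functools import reduce

def toy_cell(input_, state):
    # A stand-in cell: keeps a running sum of the inputs.
    output = input_ + state
    return output, output

def step(acc, input_):
    outputs, state = acc
    output, next_state = toy_cell(input_, state)
    return outputs + [output], next_state

outputs, final_state = reduce(step, [1, 2, 3, 4], ([], 0))
print(outputs)      # [1, 3, 6, 10]
print(final_state)  # 10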

There’s an unenforced requirement that the shapes of the current state and the next state must be the same. This restriction lets us chain like cells together, because the next step in the chain can always consume the state produced by the previous one.

It’s up to the consumers of RNN cells to take advantage of their “recurrent” nature; each cell is only a single link. So, naturally, the implementations of RNN cells aren’t necessarily recurrent themselves.

RNNCell

This is an abstract class that people subclass to implement their own cells. The docstring’s instructions for creating a concrete RNNCell are pretty clear:

Every `RNNCell` must have the properties below and implement `call` with the signature `(output, next_state) = call(input, state)`

Those “properties below” being state_size and output_size.

Pretty simple. You can see a few implementations further down in rnn_cell_impl.py.
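To make those instructions concrete, here’s a sketch of a minimal custom cell, assuming the TF 1.x-era API that rnn_cell_impl.py defines (the cell itself is invented for illustration, and variable-reuse details are glossed over):

import tensorflow as tf

class MinimalCell(tf.nn.rnn_cell.RNNCell):
    """An invented cell: next state is a dense transform of [inputs, state]."""

    def __init__(self, num_units):
        super(MinimalCell, self).__init__()
        self._num_units = num_units

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def call(self, inputs, state):
        concat = tf.concat([inputs, state], axis=1)
        output = tf.layers.dense(concat, self._num_units, activation=tf.tanh)
        return output, output  # output doubles as the next state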

BasicRNNCell

This is about as basic as a recurrent network gets: a simple implementation of RNNCell. At its core is a linear model followed by an activation function.

Here’s that in action, excerpted from BasicRNNCell.call:

output = self._activation(self._linear([inputs, state]))
return output, output

inputs is the input to an RNN cell. state is the output of the previous RNN cell. output is just a linear combination of inputs and state, passed through the activation function, so it’s basically like any fully connected layer, except with two inputs.

Why return a tuple of output, output? Because cells are expected to return a tuple of (output, next_state) when called. In BasicRNNCell, the output is the same thing as the next state; this network doesn’t distinguish between the two. That isn’t true of every cell, but returning the tuple anyway means BasicRNNCell works in the same places as cells whose outputs and states differ.
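Stripped of the TensorFlow machinery, the whole computation fits in a few lines of NumPy. This is a sketch of what _linear plus the activation boil down to, with random weights standing in for the trained variables:

import numpy as np

batch, input_size, num_units = 4, 8, 16

inputs = np.random.randn(batch, input_size)
state = np.zeros((batch, num_units))  # previous state (zeros at the first step)

# _linear builds one weight matrix over the concatenated [inputs, state], plus a bias.
W = np.random.randn(input_size + num_units, num_units)
b = np.zeros(num_units)

output = np.tanh(np.concatenate([inputs, state], axis=1) @ W + b)
next_state = output  # in BasicRNNCell, the output doubles as the next state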
