Understanding TensorFlow Source Code: RNN Cells
You can think of this as a guided reading of the first half of the code in rnn_cell_impl.py, where some RNN, LSTM, and GRU cells are defined. I’m keeping the excerpts short in the hope that you’ll follow along with the actual source, and that by the end you’ll find the TensorFlow source code more approachable.
Prep: advanced Python concepts
__call__
__call__ is a special method on Python objects; it turns them into callable objects. Callable objects, like functions or RNNCells, can be called with parentheses.
def my_callable(x):
    return x + 1

my_callable(5)
has the same result as this
class MyClass:
    def __call__(self, x):
        return x + 1

my_callable = MyClass()
my_callable(5)
@property
property is a built-in that allows you to have computed object properties. For example:
class MyClass:
    def __init__(self):
        self.x = 5

my_object = MyClass()
my_object.x
is pretty much the same thing as
class MyClass:
    def __init__(self):
        pass

    @property
    def x(self):
        return 5

my_object = MyClass()
my_object.x
This lets you make computed object properties. For example, I could have a property that goes and fetches data from the database each time it’s accessed. It’s like a “getter” in Java or C#, but more Pythonic.
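To make that concrete, here’s a small sketch of the idea: a property that computes its value lazily the first time it’s accessed. The class, attribute, and “database” here are all made up for illustration.

```python
class UserRecord:
    """Illustrative only: pretends to fetch a value on first access to .name."""

    def __init__(self, user_id):
        self.user_id = user_id
        self._name = None  # not fetched yet

    def _fetch_name_from_database(self):
        # Stand-in for a real database call.
        return "user-%d" % self.user_id

    @property
    def name(self):
        # Computed on first access, cached afterwards.
        if self._name is None:
            self._name = self._fetch_name_from_database()
        return self._name

record = UserRecord(42)
print(record.name)  # looks like plain attribute access, but runs code
```

From the caller’s point of view, `record.name` reads like a plain attribute, which is exactly why `property` is preferred over explicit `get_name()` methods in Python.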
What’s the goal of rnn_cell_impl.py?
The goal of rnn_cell_impl.py is to provide some popular RNN cells and an easy way for people to create their own cells.
An RNN cell is basically a function that takes in an input and a state, returning a tuple of the output and the next state. In programming, this is very similar to a reducer function. You can chain a bunch of these cells together, sort of like in this pseudocode:
cell = LSTM(num_units)
inputs = [...a list of data...]
state = initial_state
result = []
for input in inputs:
    output, state = cell(input, state)
    result.append(output)
return result
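To make the reducer analogy concrete, here’s a runnable toy version of that chaining loop. The “cell” here is a made-up running-sum function, not a real RNN cell; it just has the right `(output, next_state) = cell(input, state)` shape.

```python
def toy_cell(input_value, state):
    """A made-up 'cell': both the output and the next state are the running sum."""
    new_state = state + input_value
    return new_state, new_state

def run_chain(cell, inputs, initial_state):
    """Thread state through the cell across a sequence, reducer-style."""
    state = initial_state
    result = []
    for input_value in inputs:
        output, state = cell(input_value, state)
        result.append(output)
    return result

print(run_chain(toy_cell, [1, 2, 3], 0))  # [1, 3, 6]
```

Any function with this signature could be dropped into `run_chain`, which is exactly the flexibility the RNNCell interface is after.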
There’s an unenforced requirement that the shapes of the current state and the next state must be the same. This restriction allows us to chain like cells together, because the next unit in the chain can always use the output state from the previous unit.
It’s up to the consumers of RNN cells to take advantage of their “recurrent” nature. The cells are each only one link. So naturally, the implementations of RNN cells aren’t necessarily recurrent themselves.
RNNCell
This is an abstract class that helps people implement their own RNNCell implementations. The instructions for creating a concrete RNNCell are pretty clear:
Every `RNNCell` must have the properties below and implement `call` with the signature `(output, next_state) = call(input, state)`
Those “properties below” being state_size and output_size.
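As a sketch of that contract in plain Python (not real TensorFlow, so the sizes are plain integers and the “tensors” are lists), a minimal made-up cell might look like this:

```python
class IdentityCell:
    """A made-up cell satisfying the RNNCell-style contract:
    two size properties plus call(inputs, state) -> (output, next_state)."""

    def __init__(self, num_units):
        self._num_units = num_units

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def call(self, inputs, state):
        # Echo the input as the output and pass the state through untouched.
        return inputs, state

cell = IdentityCell(3)
output, next_state = cell.call([1, 2, 3], [0, 0, 0])
```

Note how both `state_size` and `output_size` use the `@property` trick from earlier, so consumers read them like plain attributes.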
Pretty simple. You can see a few implementations further down in rnn_cell_impl.py.
BasicRNNCell
This is about as basic as a recurrent network can be. It’s a simple implementation of RNNCell. At its core is a linear model followed by an activation function.
Here’s that in action, excerpted from BasicRNNCell.call:
output = self._activation(self._linear([inputs, state]))
return output, output
inputs is the input to an RNN cell. state is the output of the previous RNN cell. output is just a linear combination of inputs and state, passed through the activation function, so it’s basically like any fully connected layer, except with two inputs.
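In other words, ignoring TensorFlow’s variable handling, one step computes roughly output = activation(W · [inputs; state] + b). Here’s a NumPy sketch of that single step; the zero-valued weights are placeholders, not anything from the real implementation.

```python
import numpy as np

def basic_rnn_step(inputs, state, W, b, activation=np.tanh):
    """One BasicRNNCell-style step: concatenate inputs and state,
    apply a single linear layer, then the activation."""
    concat = np.concatenate([inputs, state])
    output = activation(W @ concat + b)
    return output, output  # the output doubles as the next state

num_inputs, num_units = 2, 3
W = np.zeros((num_units, num_inputs + num_units))  # placeholder weights
b = np.zeros(num_units)
output, next_state = basic_rnn_step(np.ones(num_inputs), np.zeros(num_units), W, b)
```

The weight matrix has shape (num_units, num_inputs + num_units) because the input and the previous state are concatenated before the single linear layer, which is what makes this “a fully connected layer with two inputs.”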
Why return a tuple of output, output? We return output, output here because RNN cells are expected to return a tuple of (output, state) when called. In BasicRNNCell, the output is the same thing as the next state; this network doesn’t discriminate between the two. That isn’t true of every cell, but it ensures that BasicRNNCell works in the same places as other cells whose outputs and states differ.