147 changes: 147 additions & 0 deletions examples/keras/NTM/README.md
@@ -0,0 +1,147 @@
### Changelog 0.2:
* API CHANGE: Controller models must now have linear activation. The activation of the NTM layer is selected
by the new parameter "activation" (default: "linear"). Everything that interacts with the memory now uses
carefully hand-selected activations, which assume that no non-linearity was applied beforehand.
This requirement on the controller will probably be final.
* There is now support for multiple read/write heads! Use the parameters read_heads and write_heads at initialisation
(both default to 1).
* The code around controller output splitting and activation was completely rewritten and cleaned of a lot of
copy-paste code.
* Unfortunately we lost backend neutrality: as tf.slice is used extensively, we would either have to try getting K.slice
or do a case distinction over backends. Use the old version if you need a backend other than TensorFlow, and
please write me a message.
* As fewer activations have to be computed, it is now a tiny bit faster (~1%).
* Stateful models do not work anymore. Actually, they never worked; the testing routine was just broken. This will be
repaired as soon as possible.

# The Neural Turing Machine
### Introduction
This code attempts to implement the Neural Turing Machine, as described in
https://arxiv.org/abs/1410.5401, as a backend-neutral recurrent Keras layer (note that, as of 0.2, only the
TensorFlow backend actually works; see the changelog above).

A standard experiment, the copy task, is provided as well.

At the end there is a TODO list. Help would be appreciated!



### User guide
For a quick start on the copy task, type

```
python main.py -v ntm
```

in a Python environment which has tensorflow, keras and numpy installed.
Having tensorflow-gpu is recommended, as everything runs about 20x faster.
In my case this experiment takes about 100 minutes on an NVIDIA GTX 1050 Ti.
The -v flag is optional; it reports much more detailed information about the achieved accuracy, and does so after
every training epoch.
Logging data is written to LOGDIR_BASE, which is ./logs/ by default. View it with TensorBoard:

```
tensorboard --logdir ./logs
```

If you're lucky and didn't have a terrible run (that can happen, unfortunately), you now have a machine capable of
copying a given sequence! I wonder if we could have achieved that any other way ...

These results are especially interesting when compared to an LSTM model: run

```
python main.py lstm
```

This builds a 3-layer LSTM model and goes through the same testing procedure
as above, which for me resulted in a training time of approximately 1h (same GPU) and
(roughly) 100%, 100%, 94%, 50%, 50% accuracy at the respective test lengths.
This shows that the NTM has advantages over the LSTM in some cases, especially considering that the LSTM model has
about 807,200 trainable parameters while the NTM has a mere 3,100!

Have fun playing around, maybe with other controllers? dense, double_dense and lstm are built in.


### API
From the outside, this implementation looks like a regular recurrent layer in Keras.
However, it has a number of non-obvious parameters:

#### Hyperparameters


* `n_slots`: The width of the memory matrix, i.e. the number of memory slots. Increasing this increases the
computational cost in O(n^2). The controller shape does not depend on it, making weight transfer possible.

* `m_depth`: The depth of the memory matrix. Increasing this increases the number of trainable weights in O(m^2). It
also changes the controller shape.

* `controller_model`: This parameter allows you to plug in a Keras model of appropriate shape as the controller. The
appropriate shape can be calculated via `controller_input_output_shape`. If set to None, a single dense layer will be
used.

* `read_heads`: The number of read heads this NTM should have. Has quadratic influence on the number of trainable
weights. Default: 1.

* `write_heads`: The number of write heads this NTM should have. Has quadratic influence on the number of trainable
weights, and for small numbers a *huge* impact. Default: 1.
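
For illustration, here is a minimal sketch (not from the original README) of an NTM layer with two read heads and one
write head; it assumes `output_dim` and `input_dim` are defined as in the usage example below, and the remaining
values are the ones used there:

```python
from ntm import NeuralTuringMachine as NTM

# Sketch only: same construction as the usage example below, but with
# read_heads=2. Expect the number of trainable weights to grow accordingly.
ntm = NTM(output_dim, n_slots=50, m_depth=20, shift_range=3,
          read_heads=2, write_heads=1,
          controller_model=None,        # None: a single dense controller is generated
          return_sequences=True,
          input_shape=(None, input_dim),
          batch_size=100)
```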


#### Usage

A more or less minimal code example:

```python
from keras.models import Sequential
from keras.optimizers import Adam
from ntm import NeuralTuringMachine as NTM

input_dim = 10        # values from the copy task
output_dim = 8
learning_rate = 5e-4
clipnorm = 10

model = Sequential()
model.name = "NTM"

ntm = NTM(output_dim, n_slots=50, m_depth=20, shift_range=3,
          controller_model=None,        # None: a single dense layer is used as controller
          return_sequences=True,
          input_shape=(None, input_dim),
          batch_size=100)
model.add(ntm)

sgd = Adam(lr=learning_rate, clipnorm=clipnorm)
model.compile(loss='binary_crossentropy', optimizer=sgd,
              metrics=['binary_accuracy'], sample_weight_mode="temporal")
```

What if we want a more complex controller instead? Design it, e.g. a double LSTM:

```python
from keras.layers.recurrent import LSTM

# batch_size, controller_input_dim and controller_output_dim are computed
# below via controller_input_output_shape.
controller = Sequential()
controller.name = "double_lstm"
controller.add(LSTM(units=150,
                    stateful=True,
                    return_sequences=True,  # required so the second LSTM receives a sequence
                    implementation=2,       # best for GPU; the other modes also might not work
                    batch_input_shape=(batch_size, None, controller_input_dim)))
controller.add(LSTM(units=controller_output_dim,
                    activation='linear',
                    stateful=True,
                    implementation=2))      # best for GPU; the other modes also might not work

controller.compile(loss='binary_crossentropy', optimizer=sgd,
                   metrics=['binary_accuracy'], sample_weight_mode="temporal")
```

Now use the same code as above, only with `controller_model=controller`.

Note that we used linear as the last activation! This is of critical importance.
The activation of the NTM layer itself can be set via the parameter `activation` (default: linear).

Note that correct values for `controller_input_dim` and `controller_output_dim` can be calculated via
`controller_input_output_shape`:

```python
from ntm import controller_input_output_shape

controller_input_dim, controller_output_dim = controller_input_output_shape(
    input_dim, output_dim, m_depth, n_slots, shift_range, read_heads, write_heads)
```


Also note that every stateful controller must carry around its own state, as was done here with `stateful=True`.





## TODO:
- [x] Arbitrary number of read and write heads
- [ ] Support for masking, and maybe dropout; one has to reason about it theoretically first.
- [ ] Support for get_config/set_config to better enable model saving
- [x] A bit of code cleaning: especially the controller output splitting was ugly as hell.
- [x] Support for arbitrary activation functions (previously restricted to sigmoid).
- [ ] Make it backend neutral again! Some testing might be nice, too.
- [ ] Maybe add the other experiments of the original paper?
- [ ] Mooaaar speeeed. Check whether there are blatant performance optimizations possible.
48 changes: 48 additions & 0 deletions examples/keras/NTM/copyTask.py
@@ -0,0 +1,48 @@
import numpy as np


def get_sample(batch_size=128, in_bits=10, out_bits=8, max_size=20, min_size=1):
    # in order to be a generator, we start with an endless loop:
    while True:
        # generate samples with random length.
        # there are two flags: one for the beginning of the sequence
        # (only the second-to-last bit is one)
        # and one for the end of the sequence (only the last bit is one).
        # at every other time step both are zero.
        # therefore the length of the generated sample is:
        # 1 + actual_sequence_length + 1 + actual_sequence_length

        # make flags
        begin_flag = np.zeros((1, in_bits))
        begin_flag[0, in_bits-2] = 1
        end_flag = np.zeros((1, in_bits))
        end_flag[0, in_bits-1] = 1

        # initialize arrays: for processing, every sequence must be of the same length.
        # We pad with zeros.
        temporal_length = max_size*2 + 2
        # "Nothing" on our band is represented by 0.5 to prevent immense bias towards 0 or 1.
        inp = np.ones((batch_size, temporal_length, in_bits))*0.5
        out = np.ones((batch_size, temporal_length, out_bits))*0.5
        # sample weights: in order to make recalling the sequence much more important than having everything set to 0
        # before and after, we construct a weight vector which is 1 where the sequence should be recalled, and small
        # anywhere else.
        sw = np.ones((batch_size, temporal_length))*0.01

        # make the actual sequences
        for i in range(batch_size):
            ts = np.random.randint(low=min_size, high=max_size+1)
            actual_sequence = np.random.uniform(size=(ts, out_bits)) > 0.5
            output_sequence = np.concatenate((np.ones((ts+2, out_bits))*0.5, actual_sequence), axis=0)

            # pad with zeros where only the flags should be one
            padded_sequence = np.concatenate((actual_sequence, np.zeros((ts, 2))), axis=1)
            input_sequence = np.concatenate((begin_flag, padded_sequence, end_flag), axis=0)

            # this embeds them, padding with the neutral value 0.5 automatically
            inp[i, :input_sequence.shape[0]] = input_sequence
            out[i, :output_sequence.shape[0]] = output_sequence
            sw[i, ts+2 : ts+2+ts] = 1

        yield inp, out, sw
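

# Hedged usage sketch (not part of the original file): draw one batch from the
# generator and check the shapes. With max_size=5, temporal_length = 5*2 + 2 = 12.
if __name__ == "__main__":
    gen = get_sample(batch_size=4, in_bits=10, out_bits=8, max_size=5)
    inp, out, sw = next(gen)
    print(inp.shape, out.shape, sw.shape)  # -> (4, 12, 10) (4, 12, 8) (4, 12)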
91 changes: 91 additions & 0 deletions examples/keras/NTM/main.py
@@ -0,0 +1,91 @@
import argparse


from keras.layers.core import Dense
from keras.layers.recurrent import LSTM
from keras.models import Sequential
from keras.optimizers import Adam, SGD
from keras.initializers import RandomNormal

output_dim = 8
input_dim = output_dim + 2  # the actual input dim of the network; it includes two extra dims for the flags
batch_size = 100
read_heads = 1
write_heads = 1

#testrange=[5,10,20,40,80,160]


parser = argparse.ArgumentParser()
parser.add_argument("modelType", help="The kind of model you want to test: ntm, dense or lstm")
parser.add_argument("-e", "--epochs", help="The number of epochs to train", default="1000", type=int)
parser.add_argument("-c", "--ntm_controller_architecture", help="""Valid choices are: dense, double_dense or
                    lstm. Ignored if the model is not ntm.""", default="dense")
parser.add_argument("-v", "--verboose", help="""Verbose training: if enabled, the model is evaluated extensively
                    after each training epoch.""", action="store_true")
args = parser.parse_args()
modelType = args.modelType
epochs = args.epochs
ntm_controller_architecture = args.ntm_controller_architecture
verboose = args.verboose

lr = 5e-4
clipnorm = 10
sgd = Adam(lr=lr, clipnorm=clipnorm)
sameInit = RandomNormal(seed=0)

if modelType == 'lstm':
    import model_lstm
    model = model_lstm.gen_model(input_dim=input_dim, output_dim=output_dim, batch_size=batch_size)

elif modelType == 'dense':
    import model_dense
    model = model_dense.gen_model(input_dim=input_dim, output_dim=output_dim, batch_size=batch_size)

elif modelType == 'ntm':
    import model_ntm
    from ntm import controller_input_output_shape as controller_shape

    controller_input_dim, controller_output_dim = controller_shape(input_dim, output_dim, 20, 128, 3, read_heads,
                                                                   write_heads)

    controller = Sequential()
    controller.name = ntm_controller_architecture
    if ntm_controller_architecture == "dense":
        controller.add(Dense(units=controller_output_dim,
                             kernel_initializer=sameInit,
                             bias_initializer=sameInit,
                             activation='linear',
                             input_dim=controller_input_dim))
    elif ntm_controller_architecture == "double_dense":
        controller.add(Dense(units=150,
                             kernel_initializer=sameInit,
                             bias_initializer=sameInit,
                             activation='linear',
                             input_dim=controller_input_dim))
        controller.add(Dense(units=controller_output_dim,
                             kernel_initializer=sameInit,
                             bias_initializer=sameInit,
                             activation='linear'))
    elif ntm_controller_architecture == "lstm":
        controller.add(LSTM(units=controller_output_dim,
                            kernel_initializer='random_normal',
                            bias_initializer='random_normal',
                            activation='linear',
                            stateful=True,
                            implementation=2,  # best for GPU; the other modes also might not work
                            batch_input_shape=(batch_size, None, controller_input_dim)))
    else:
        raise ValueError("This controller_architecture is not implemented.")

    controller.compile(loss='binary_crossentropy', optimizer=sgd,
                       metrics=['binary_accuracy'], sample_weight_mode="temporal")

    model = model_ntm.gen_model(input_dim=input_dim, output_dim=output_dim, batch_size=batch_size,
                                controller_model=controller, read_heads=read_heads, write_heads=write_heads,
                                activation="sigmoid")
else:
    raise ValueError("This model is not implemented.")

print("model built, starting the copy experiment")
from testing_utils import lengthy_test
lengthy_test(model, epochs=epochs, verboose=verboose)
25 changes: 25 additions & 0 deletions examples/keras/NTM/model_dense.py
@@ -0,0 +1,25 @@
import keras
from keras.models import Sequential
from keras.layers import Activation, Dense
from keras.optimizers import Adam

lr = 4e-4
clipnorm = 10
units = 256



def gen_model(input_dim=10, output_dim=8, batch_size=100):
    model_dense = Sequential()
    model_dense.name = "FFW"
    model_dense.batch_size = batch_size
    model_dense.input_dim = input_dim
    model_dense.output_dim = output_dim

    model_dense.add(Dense(input_shape=(None, input_dim), units=output_dim))
    model_dense.add(Activation('sigmoid'))

    sgd = Adam(lr=lr, clipnorm=clipnorm)
    model_dense.compile(loss='binary_crossentropy', optimizer=sgd,
                        metrics=['binary_accuracy'], sample_weight_mode="temporal")

    return model_dense
28 changes: 28 additions & 0 deletions examples/keras/NTM/model_lstm.py
@@ -0,0 +1,28 @@
import keras
from keras.models import Sequential
from keras.layers import LSTM, Activation
from keras.optimizers import Adam

batch_size = 100
lr = 5e-4
clipnorm = 10
units = 256



def gen_model(input_dim=10, output_dim=8, batch_size=100):
    model_LSTM = Sequential()
    model_LSTM.name = "LSTM"
    model_LSTM.batch_size = batch_size
    model_LSTM.input_dim = input_dim
    model_LSTM.output_dim = output_dim

    model_LSTM.add(LSTM(input_shape=(None, input_dim), units=units, return_sequences=True))
    model_LSTM.add(LSTM(units=units, return_sequences=True))
    model_LSTM.add(LSTM(units=output_dim, return_sequences=True))
    model_LSTM.add(Activation('sigmoid'))

    sgd = Adam(lr=lr, clipnorm=clipnorm)
    model_LSTM.compile(loss='binary_crossentropy', optimizer=sgd,
                       metrics=['binary_accuracy'], sample_weight_mode="temporal")

    return model_LSTM
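

# Hedged usage sketch (not part of the original file): build the LSTM baseline
# from the README comparison and print its size; with these defaults it has
# about 807,200 trainable parameters.
if __name__ == "__main__":
    model = gen_model(input_dim=10, output_dim=8, batch_size=100)
    model.summary()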
47 changes: 47 additions & 0 deletions examples/keras/NTM/model_ntm.py
@@ -0,0 +1,47 @@
import numpy as np

from keras.layers.core import Activation
from keras.layers.wrappers import TimeDistributed
from keras.models import Sequential
from keras.optimizers import Adam
from keras import backend as K
import keras

from ntm import NeuralTuringMachine as NTM


n_slots = 128
m_depth = 20
learning_rate = 5e-4
clipnorm = 10

def gen_model(input_dim, batch_size, output_dim,
              n_slots=n_slots,
              m_depth=m_depth,
              controller_model=None,
              activation="sigmoid",
              read_heads=1,
              write_heads=1):

    model = Sequential()
    # guard against controller_model=None, which has no .name attribute
    model.name = "NTM_-_" + (controller_model.name if controller_model is not None else "dense")
    model.batch_size = batch_size
    model.input_dim = input_dim
    model.output_dim = output_dim

    ntm = NTM(output_dim, n_slots=n_slots, m_depth=m_depth, shift_range=3,
              controller_model=controller_model,
              activation=activation,
              read_heads=read_heads,
              write_heads=write_heads,
              return_sequences=True,
              input_shape=(None, input_dim),
              batch_size=batch_size)
    model.add(ntm)

    sgd = Adam(lr=learning_rate, clipnorm=clipnorm)
    model.compile(loss='binary_crossentropy', optimizer=sgd,
                  metrics=['binary_accuracy'], sample_weight_mode="temporal")

    return model
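

# Hedged usage sketch (not part of the original file, mirrors main.py): build a
# minimal dense controller of the right shape and hand it to gen_model. The
# argument order of controller_input_output_shape follows the README:
# (input_dim, output_dim, m_depth, n_slots, shift_range, read_heads, write_heads).
if __name__ == "__main__":
    from keras.layers.core import Dense
    from ntm import controller_input_output_shape

    c_in, c_out = controller_input_output_shape(10, 8, m_depth, n_slots, 3, 1, 1)
    controller = Sequential()
    controller.name = "dense"
    controller.add(Dense(units=c_out, activation='linear', input_dim=c_in))
    controller.compile(loss='binary_crossentropy',
                       optimizer=Adam(lr=learning_rate, clipnorm=clipnorm),
                       metrics=['binary_accuracy'], sample_weight_mode="temporal")

    model = gen_model(input_dim=10, batch_size=100, output_dim=8,
                      controller_model=controller)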

