author    Alex Auvolat <alex.auvolat@ens.fr>  2015-06-10 15:22:18 -0400
committer Alex Auvolat <alex.auvolat@ens.fr>  2015-06-10 15:22:18 -0400
commit    9be0db7523abdfa59c19115585f1ee96d73d08c6 (patch)
tree      cabc1b8db78ad56fde1fa6a3bc2760a2883d4d5c
parent    c5e1cd9c8c896096ad1630909a655b06eb398abb (diff)
Changes
-rw-r--r--  .gitignore |  1
-rw-r--r--  lstm.py    | 75
-rwxr-xr-x  train.py   | 64
3 files changed, 88 insertions(+), 52 deletions(-)
diff --git a/.gitignore b/.gitignore
index cec3f1e..3d6982a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
*.pyc
*.swp
data/*
+model_data/*
diff --git a/lstm.py b/lstm.py
index 32cdb9b..e294793 100644
--- a/lstm.py
+++ b/lstm.py
@@ -13,36 +13,48 @@ from blocks.graph import ComputationGraph, apply_noise, apply_dropout
# An epoch will be composed of 'num_seqs' sequences of length 'seq_len',
# divided into chunks of length 'seq_div_size'
-num_seqs = 10
+num_seqs = 20
seq_len = 2000
seq_div_size = 100
io_dim = 256
-hidden_dims = [512, 512]
+hidden_dims = [512, 512, 512]
activation_function = Tanh()
-all_hidden_for_output = False
+i2h_all = True # input goes to all hidden layers, or only to the first
+h2o_all = True # all hidden layers feed the output, or only the last
w_noise_std = 0.01
i_dropout = 0.5
-step_rule = 'adadelta'
+step_rule = 'momentum'
+learning_rate = 0.1
+momentum = 0.9
-param_desc = '%s-%sHO-n%s-d%s-%dx%d(%d)-%s' % (
+param_desc = '%s-%sIH,%sHO-n%s-d%s-%dx%d(%d)-%s' % (
repr(hidden_dims),
- 'all' if all_hidden_for_output else 'last',
+ 'all' if i2h_all else 'first',
+ 'all' if h2o_all else 'last',
repr(w_noise_std),
repr(i_dropout),
num_seqs, seq_len, seq_div_size,
step_rule
)
+save_freq = 5
+
+# parameters for sample generation
+sample_len = 60
+sample_temperature = 0.3
+
if step_rule == 'rmsprop':
step_rule = RMSProp()
elif step_rule == 'adadelta':
step_rule = AdaDelta()
+elif step_rule == 'momentum':
+ step_rule = Momentum(learning_rate=learning_rate, momentum=momentum)
else:
assert(False)
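
For reference, the Momentum step rule selected above corresponds to the classical heavy-ball update: keep a running velocity per parameter and descend along it instead of along the raw gradient. A minimal self-contained sketch on a toy objective (the names `param`, `velocity`, and `grad` are illustrative, not Blocks API):

    # Toy objective f(p) = 0.5 * p**2, whose gradient is simply p.
    param, velocity = 5.0, 0.0
    learning_rate, momentum = 0.1, 0.9   # the same values set in the config above
    for _ in range(200):
        grad = param                     # gradient of the toy objective
        velocity = momentum * velocity - learning_rate * grad
        param = param + velocity
    print param                          # decays towards 0.0
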
@@ -52,7 +64,9 @@ class Model():
in_onehot = tensor.eq(tensor.arange(io_dim, dtype='int16').reshape((1, 1, io_dim)),
inp[:, :, None])
+ in_onehot.name = 'in_onehot'
+ # Construct hidden states
dims = [io_dim] + hidden_dims
states = [in_onehot.dimshuffle(1, 0, 2)]
bricks = []
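
The `tensor.eq` expression above builds the one-hot encoding purely by broadcasting: an arange of shape (1, 1, io_dim) is compared against the input expanded to (batch, time, 1), yielding a (batch, time, io_dim) indicator. A numpy equivalent with toy shapes:

    import numpy

    io_dim = 4                                    # tiny alphabet for illustration
    inp = numpy.array([[1, 3, 0]], dtype='int16') # shape (batch=1, time=3)
    onehot = (numpy.arange(io_dim, dtype='int16').reshape((1, 1, io_dim))
              == inp[:, :, None])                 # broadcasts to (1, 3, io_dim)
    print onehot.astype('int16')
    # [[[0 1 0 0]
    #   [0 0 0 1]
    #   [1 0 0 0]]]
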
@@ -65,38 +79,44 @@ class Model():
linear = Linear(input_dim=dims[i-1], output_dim=4*dims[i],
name="lstm_in_%d"%i)
+ bricks.append(linear)
+ inter = linear.apply(states[-1])
+
+ if i2h_all and i > 1:
+ linear2 = Linear(input_dim=dims[0], output_dim=4*dims[i],
+ name="lstm_in0_%d"%i)
+ bricks.append(linear2)
+ inter = inter + linear2.apply(states[0])
+ inter.name = 'inter_bis_%d'%i
+
lstm = LSTM(dim=dims[i], activation=activation_function,
name="lstm_rec_%d"%i)
+ bricks.append(lstm)
- new_states, new_cells = lstm.apply(linear.apply(states[-1]),
+ new_states, new_cells = lstm.apply(inter,
states=init_state,
cells=init_cell)
updates.append((init_state, new_states[-1, :, :]))
updates.append((init_cell, new_cells[-1, :, :]))
states.append(new_states)
- bricks = bricks + [linear, lstm]
- states = [s.dimshuffle(1, 0, 2).reshape((inp.shape[0] * inp.shape[1], dim))
- for dim, s in zip(dims, states)]
+ states = [s.dimshuffle(1, 0, 2) for s in states]
- if all_hidden_for_output:
- top_linear = MLP(dims=[sum(hidden_dims), io_dim],
- activations=[Softmax()],
- name="pred_mlp")
+ # Construct output from hidden states
+ out = None
+ layers = zip(dims, states)[1:]
+ if not h2o_all:
+ layers = [layers[-1]]
+ for i, (dim, state) in enumerate(layers):
+ top_linear = Linear(input_dim=dim, output_dim=io_dim,
+ name='top_linear_%d'%i)
bricks.append(top_linear)
+ out_i = top_linear.apply(state)
+ out = out_i if out is None else out + out_i
+ out.name = 'out_part_%d'%i
- out = top_linear.apply(tensor.concatenate(states[1:], axis=1))
- else:
- top_linear = MLP(dims=[hidden_dims[-1], io_dim],
- activations=[None],
- name="pred_mlp")
- bricks.append(top_linear)
-
- out = top_linear.apply(states[-1])
-
- out = out.reshape((inp.shape[0], inp.shape[1], io_dim))
-
+ # Do prediction and calculate cost
pred = out.argmax(axis=2)
cost = Softmax().categorical_cross_entropy(inp[:, 1:].flatten(),
@@ -104,13 +124,13 @@ class Model():
io_dim)))
error_rate = tensor.neq(inp[:, 1:].flatten(), pred[:, :-1].flatten()).mean()
- # Initialize
+ # Initialize all bricks
for brick in bricks:
brick.weights_init = IsotropicGaussian(0.1)
brick.biases_init = Constant(0.)
brick.initialize()
- # apply noise
+ # Apply noise and dropout
cg = ComputationGraph([cost, error_rate])
if w_noise_std > 0:
noise_vars = VariableFilter(roles=[WEIGHT])(cg)
@@ -123,6 +143,7 @@ class Model():
self.error_rate = error_rate
self.cost_reg = cost_reg
self.error_rate_reg = error_rate_reg
+ self.out = out
self.pred = pred
self.updates = updates
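
Taken together, the rewritten model gives every LSTM layer an optional skip connection from the raw input (`i2h_all`) and lets every layer, or only the last, project into the output (`h2o_all`), with per-layer projections summed rather than concatenated. A sketch of the resulting wiring, with placeholder callables standing in for the Linear and LSTM bricks:

    def forward(x, layers, i2h_all=True, h2o_all=True):
        # x: one-hot input sequence; layers: list of
        # (lstm_in, lstm_in0, lstm, top_linear) placeholder callables.
        states = [x]
        for i, (lstm_in, lstm_in0, lstm, _) in enumerate(layers):
            pre = lstm_in(states[-1])        # projection of the layer below
            if i2h_all and i > 0:
                pre = pre + lstm_in0(x)      # extra skip connection from the input
            states.append(lstm(pre))
        outs = [top(s) for (_, _, _, top), s in zip(layers, states[1:])]
        if not h2o_all:
            outs = outs[-1:]                 # only the topmost layer reaches the output
        return sum(outs)                     # per-layer outputs are summed
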
diff --git a/train.py b/train.py
index 7857f3f..a8e9ef2 100755
--- a/train.py
+++ b/train.py
@@ -5,14 +5,17 @@ import numpy
import sys
import importlib
+from contextlib import closing
+
import theano
from theano import tensor
+from theano.tensor.shared_randomstreams import RandomStreams
-from blocks.dump import load_parameter_values
-from blocks.dump import MainLoopDumpManager
+from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER
from blocks.extensions import Printing, SimpleExtension
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-from blocks.extensions.plot import Plot
+from blocks.extras.extensions.plot import Plot
+from blocks.extensions.saveload import Checkpoint, Load
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
@@ -37,10 +40,14 @@ class GenText(SimpleExtension):
self.init_text = init_text
self.max_bytes = max_bytes
- cg = ComputationGraph([model.pred])
+
+ out = model.out[:, -1, :] / numpy.float32(config.sample_temperature)
+ prob = tensor.nnet.softmax(out)
+
+ cg = ComputationGraph([prob])
assert(len(cg.inputs) == 1)
assert(cg.inputs[0].name == 'bytes')
- self.f = theano.function(inputs=cg.inputs, outputs=[model.pred])
+ self.f = theano.function(inputs=cg.inputs, outputs=[prob])
super(GenText, self).__init__(**kwargs)
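
GenText now samples from a temperature-scaled softmax instead of taking an argmax: dividing the logits by `sample_temperature` (0.3 above) sharpens the distribution, so low temperatures stay close to greedy decoding while higher ones add diversity. A numpy illustration with made-up logits:

    import numpy

    def sample_byte(logits, temperature):
        p = numpy.exp(logits / temperature)
        p = p / p.sum()
        p = p / 1.00001                      # same rounding guard as the sampling loop below
        return numpy.random.multinomial(1, p).nonzero()[0][0]

    logits = numpy.array([2.0, 1.0, 0.5])
    print sample_byte(logits, 0.3)           # nearly always 0: low T ~ greedy argmax
    print sample_byte(logits, 3.0)           # higher T flattens the distribution
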
@@ -49,22 +56,21 @@ class GenText(SimpleExtension):
dtype='int16')[None, :].repeat(axis=0, repeats=config.num_seqs)
while v.shape[1] < self.max_bytes:
- pred, = self.f(v)
- v = numpy.concatenate([v, pred[:, -1:]], axis=1)
+ prob, = self.f(v)
+ prob = prob / 1.00001 # float32 rounding can push sum(prob) past 1, which multinomial rejects
+ pred = numpy.zeros((prob.shape[0],), dtype='int16')
+ for i in range(prob.shape[0]):
+ pred[i] = numpy.random.multinomial(1, prob[i, :]).nonzero()[0][0]
+ v = numpy.concatenate([v, pred[:, None]], axis=1)
for i in range(v.shape[0]):
print "Sample:", ''.join([chr(int(v[i, j])) for j in range(v.shape[1])])
-def train_model(m, train_stream, load_location=None, save_location=None):
+def train_model(m, train_stream, dump_path=None):
# Define the model
model = Model(m.cost)
- # Load the parameters from a dumped model
- if load_location is not None:
- logger.info('Loading parameters...')
- model.set_param_values(load_parameter_values(load_location))
-
cg = ComputationGraph(m.cost_reg)
algorithm = GradientDescent(cost=m.cost_reg,
step_rule=config.step_rule,
@@ -72,11 +78,26 @@ def train_model(m, train_stream, load_location=None, save_location=None):
algorithm.add_updates(m.updates)
+ # Load the parameters from a dumped model
+ if dump_path is not None:
+ try:
+ logger.info('Loading parameters...')
+ with closing(numpy.load(dump_path)) as source:
+ param_values = {'/' + name.replace(BRICK_DELIMITER, '/'): source[name]
+ for name in source.keys()
+ if name != 'pkl' and not 'None' in name}
+ model.set_param_values(param_values)
+ except IOError:
+ pass
+
main_loop = MainLoop(
model=model,
data_stream=train_stream,
algorithm=algorithm,
extensions=[
+ Checkpoint(path=dump_path,
+ after_epoch=False, every_n_epochs=config.save_freq),
+
TrainingDataMonitoring(
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
prefix='train', every_n_epochs=1),
@@ -84,19 +105,14 @@ def train_model(m, train_stream, load_location=None, save_location=None):
Plot(document='tr_'+model_name+'_'+config.param_desc,
channels=[['train_cost', 'train_cost_reg'],
['train_error_rate', 'train_error_rate_reg']],
+ server_url='http://eos21:4201/',
every_n_epochs=1, after_epoch=False),
- GenText(m, '\t', 20, every_n_epochs=1, after_epoch=False)
+
+ GenText(m, ' ', config.sample_len, every_n_epochs=1, after_epoch=False)
]
)
main_loop.run()
- # Save the main loop
- if save_location is not None:
- logger.info('Saving the main loop...')
- dump_manager = MainLoopDumpManager(save_location)
- dump_manager.dump(main_loop)
- logger.info('Saved')
-
if __name__ == "__main__":
# Build datastream
@@ -114,8 +130,6 @@ if __name__ == "__main__":
m.pred.name = 'pred'
# Train the model
- saveloc = 'model_data/%s' % model_name
- train_model(m, train_stream,
- load_location=None,
- save_location=None)
+ saveloc = 'model_data/%s-%s' % (model_name, config.param_desc)
+ train_model(m, train_stream, dump_path=saveloc)
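
For clarity, the new loading block in train_model can be read in isolation: Checkpoint dumps a numpy archive whose keys join brick names with BRICK_DELIMITER, and the loader restores '/'-separated Blocks parameter paths, skipping the pickled main loop entry. A standalone sketch mirroring that block (the path 'model_data/example' is hypothetical):

    from contextlib import closing

    import numpy
    from blocks.serialization import BRICK_DELIMITER

    with closing(numpy.load('model_data/example')) as source:
        # Turn archive keys back into Blocks parameter paths and
        # drop the 'pkl' entry and any unnamed ('None') parameters.
        param_values = dict(
            ('/' + name.replace(BRICK_DELIMITER, '/'), source[name])
            for name in source.keys()
            if name != 'pkl' and 'None' not in name)
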