aboutsummaryrefslogtreecommitdiff
path: root/main_dummy_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'main_dummy_dataset.py')
-rwxr-xr-xmain_dummy_dataset.py143
1 files changed, 143 insertions, 0 deletions
diff --git a/main_dummy_dataset.py b/main_dummy_dataset.py
new file mode 100755
index 0000000..9240358
--- /dev/null
+++ b/main_dummy_dataset.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python
+
+import theano
+import numpy
+from theano import tensor
+from blocks.model import Model
+from blocks.bricks import Linear, Tanh
+from blocks.bricks.lookup import LookupTable
+from ctc import CTC
+from blocks.initialization import IsotropicGaussian, Constant
+from fuel.datasets import IterableDataset
+from fuel.streams import DataStream
+from blocks.algorithms import (GradientDescent, Scale, AdaDelta, RemoveNotFinite,
+ StepClipping, CompositeRule)
+from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring
+from blocks.main_loop import MainLoop
+from blocks.extensions import FinishAfter, Printing
+from blocks.bricks.recurrent import SimpleRecurrent, LSTM
+from blocks.graph import ComputationGraph
+
+from dummy_dataset import setup_datastream
+
+from edit_distance import batch_edit_distance
+from blocks.extras.extensions.plot import Plot
+
+floatX = theano.config.floatX
+
+
+n_epochs = 10000
+num_input_classes = 5
+h_dim = 40
+rec_dim = 40
+num_output_classes = 4
+
+
+print('Building model ...') # ----------- THE MODEL --------------------------
+
+inputt = tensor.lmatrix('input').T
+input_mask = tensor.matrix('input_mask').T
+y = tensor.lmatrix('output').T
+y_mask = tensor.matrix('output_mask').T
+y_len = y_mask.sum(axis=0)
+# inputt : T x B
+# input_mask : T x B
+# y : L x B
+# y_mask : L x B
+
+# Linear bricks in
+input_to_h = LookupTable(num_input_classes, h_dim, name='lookup')
+h = input_to_h.apply(inputt)
+# h : T x B x h_dim
+
+# RNN bricks
+pre_lstm = Linear(input_dim=h_dim, output_dim=4*rec_dim, name='LSTM_linear')
+lstm = LSTM(activation=Tanh(),
+ dim=rec_dim, name="rnn")
+rnn_out, _ = lstm.apply(pre_lstm.apply(h), mask=input_mask)
+
+# Linear bricks out
+rec_to_o = Linear(name='rec_to_o',
+ input_dim=rec_dim,
+ output_dim=num_output_classes + 1)
+y_hat_pre = rec_to_o.apply(rnn_out)
+# y_hat_pre : T x B x C+1
+
+# y_hat : T x B x C+1
+y_hat = tensor.nnet.softmax(
+ y_hat_pre.reshape((-1, num_output_classes + 1))
+).reshape((y_hat_pre.shape[0], y_hat_pre.shape[1], -1))
+y_hat.name = 'y_hat'
+
+y_hat_mask = input_mask
+
+# Cost
+cost = CTC().apply_log_domain(y, y_hat, y_len, y_hat_mask).mean()
+cost.name = 'CTC'
+
+dl, dl_length = CTC().best_path_decoding(y_hat, y_hat_mask)
+
+edit_distances = batch_edit_distance(dl.T.astype('int32'), dl_length, y.T.astype('int32'),
+ y_len.astype('int32'))
+edit_distance = edit_distances.mean()
+edit_distance.name = 'edit_distance'
+errors_per_char = (edit_distances / y_len).mean()
+errors_per_char.name = 'errors_per_char'
+
+L = y.shape[0]
+B = y.shape[1]
+dl = dl[:L, :]
+is_error = tensor.neq(dl, y) * tensor.lt(tensor.arange(L)[:,None], y_len[None,:])
+is_error = tensor.switch(is_error.sum(axis=0), tensor.ones((B,)), tensor.neq(y_len, dl_length))
+
+error_rate = is_error.mean()
+error_rate.name = 'error_rate'
+
+
+# Initialization
+for brick in [input_to_h, pre_lstm, lstm, rec_to_o]:
+ brick.weights_init = IsotropicGaussian(0.01)
+ brick.biases_init = Constant(0)
+ brick.initialize()
+
+print('Bulding DataStream ...') # ---------------------------------------------------
+ds, stream = setup_datastream(batch_size=100,
+ nb_examples=10000, rng_seed=123,
+ min_out_len=5, max_out_len=20)
+valid_ds, valid_stream = setup_datastream(batch_size=100,
+ nb_examples=1000, rng_seed=456,
+ min_out_len=5, max_out_len=20)
+
+print('Bulding training process...') # ----------------------------------------------
+algorithm = GradientDescent(cost=cost,
+ parameters=ComputationGraph(cost).parameters,
+ step_rule=CompositeRule([RemoveNotFinite(), AdaDelta()]))
+ # CompositeRule([StepClipping(10.0), Scale(0.02)]))
+monitor_cost = TrainingDataMonitoring([cost, error_rate],
+ prefix="train",
+ after_epoch=True)
+
+monitor_valid = DataStreamMonitoring([cost, error_rate, edit_distance, errors_per_char],
+ data_stream=valid_stream,
+ prefix="valid",
+ after_epoch=True)
+
+plot = Plot(document='CTC_dummy_dataset_%d_%d'%(h_dim, rec_dim),
+ channels=[['train_CTC', 'valid_CTC'],
+ ['train_error_rate', 'valid_error_rate'],
+ ['valid_edit_distance'],
+ ['valid_errors_per_char']],
+ after_epoch=True)
+
+model = Model(cost)
+main_loop = MainLoop(data_stream=stream, algorithm=algorithm,
+ extensions=[monitor_cost, monitor_valid, plot,
+ FinishAfter(after_n_epochs=n_epochs),
+ Printing()],
+ model=model)
+
+print('Starting training ...') # ---------------------------------------------------
+main_loop.run()
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et: