| field | value | date |
|---|---|---|
| author | Thomas Mesnard <thomas.mesnard@ens.fr> | 2015-12-28 20:51:50 +0100 |
| committer | Alex Auvolat <alex@adnab.me> | 2016-04-21 10:21:42 +0200 |
| commit | 072d26e766931007a0f243674f7dfdff5c3104e9 (patch) | |
| tree | ae3639f4ff3f8e0e3e9767c15322171aa6f2169e /main_dummy_dataset.py | |
| parent | e8e37dee0c5c846b1aa2dd24dc99095191f72a9b (diff) | |
| download | pgm-ctc-072d26e766931007a0f243674f7dfdff5c3104e9.tar.gz, pgm-ctc-072d26e766931007a0f243674f7dfdff5c3104e9.zip | |
Add plot
More TIMIT ; log domain
TIMIT: more complexity
Nice poster
Beautify code (mostly, add comments)
Add final stuff.
Diffstat (limited to 'main_dummy_dataset.py')
-rwxr-xr-x | main_dummy_dataset.py | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/main_dummy_dataset.py b/main_dummy_dataset.py
new file mode 100755
index 0000000..9240358
--- /dev/null
+++ b/main_dummy_dataset.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python
+
+import theano
+import numpy
+from theano import tensor
+from blocks.model import Model
+from blocks.bricks import Linear, Tanh
+from blocks.bricks.lookup import LookupTable
+from ctc import CTC
+from blocks.initialization import IsotropicGaussian, Constant
+from fuel.datasets import IterableDataset
+from fuel.streams import DataStream
+from blocks.algorithms import (GradientDescent, Scale, AdaDelta, RemoveNotFinite,
+                               StepClipping, CompositeRule)
+from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring
+from blocks.main_loop import MainLoop
+from blocks.extensions import FinishAfter, Printing
+from blocks.bricks.recurrent import SimpleRecurrent, LSTM
+from blocks.graph import ComputationGraph
+
+from dummy_dataset import setup_datastream
+
+from edit_distance import batch_edit_distance
+from blocks.extras.extensions.plot import Plot
+
+floatX = theano.config.floatX
+
+
+n_epochs = 10000
+num_input_classes = 5
+h_dim = 40
+rec_dim = 40
+num_output_classes = 4
+
+
+print('Building model ...')     # ----------- THE MODEL --------------------------
+
+inputt = tensor.lmatrix('input').T
+input_mask = tensor.matrix('input_mask').T
+y = tensor.lmatrix('output').T
+y_mask = tensor.matrix('output_mask').T
+y_len = y_mask.sum(axis=0)
+# inputt : T x B
+# input_mask : T x B
+# y : L x B
+# y_mask : L x B
+
+# Linear bricks in
+input_to_h = LookupTable(num_input_classes, h_dim, name='lookup')
+h = input_to_h.apply(inputt)
+# h : T x B x h_dim
+
+# RNN bricks
+pre_lstm = Linear(input_dim=h_dim, output_dim=4*rec_dim, name='LSTM_linear')
+lstm = LSTM(activation=Tanh(),
+            dim=rec_dim, name="rnn")
+rnn_out, _ = lstm.apply(pre_lstm.apply(h), mask=input_mask)
+
+# Linear bricks out
+rec_to_o = Linear(name='rec_to_o',
+                  input_dim=rec_dim,
+                  output_dim=num_output_classes + 1)
+y_hat_pre = rec_to_o.apply(rnn_out)
+# y_hat_pre : T x B x C+1
+
+# y_hat : T x B x C+1
+y_hat = tensor.nnet.softmax(
+    y_hat_pre.reshape((-1, num_output_classes + 1))
+).reshape((y_hat_pre.shape[0], y_hat_pre.shape[1], -1))
+y_hat.name = 'y_hat'
+
+y_hat_mask = input_mask
+
+# Cost
+cost = CTC().apply_log_domain(y, y_hat, y_len, y_hat_mask).mean()
+cost.name = 'CTC'
+
+dl, dl_length = CTC().best_path_decoding(y_hat, y_hat_mask)
+
+edit_distances = batch_edit_distance(dl.T.astype('int32'), dl_length, y.T.astype('int32'),
+                                     y_len.astype('int32'))
+edit_distance = edit_distances.mean()
+edit_distance.name = 'edit_distance'
+errors_per_char = (edit_distances / y_len).mean()
+errors_per_char.name = 'errors_per_char'
+
+L = y.shape[0]
+B = y.shape[1]
+dl = dl[:L, :]
+is_error = tensor.neq(dl, y) * tensor.lt(tensor.arange(L)[:,None], y_len[None,:])
+is_error = tensor.switch(is_error.sum(axis=0), tensor.ones((B,)), tensor.neq(y_len, dl_length))
+
+error_rate = is_error.mean()
+error_rate.name = 'error_rate'
+
+
+# Initialization
+for brick in [input_to_h, pre_lstm, lstm, rec_to_o]:
+    brick.weights_init = IsotropicGaussian(0.01)
+    brick.biases_init = Constant(0)
+    brick.initialize()
+
+print('Building DataStream ...')    # ---------------------------------------------------
+ds, stream = setup_datastream(batch_size=100,
+                              nb_examples=10000, rng_seed=123,
+                              min_out_len=5, max_out_len=20)
+valid_ds, valid_stream = setup_datastream(batch_size=100,
+                                          nb_examples=1000, rng_seed=456,
+                                          min_out_len=5, max_out_len=20)
+
+print('Building training process...')   # ----------------------------------------------
+algorithm = GradientDescent(cost=cost,
+                            parameters=ComputationGraph(cost).parameters,
+                            step_rule=CompositeRule([RemoveNotFinite(), AdaDelta()]))
+                            # CompositeRule([StepClipping(10.0), Scale(0.02)]))
+monitor_cost = TrainingDataMonitoring([cost, error_rate],
+                                      prefix="train",
+                                      after_epoch=True)
+
+monitor_valid = DataStreamMonitoring([cost, error_rate, edit_distance, errors_per_char],
+                                     data_stream=valid_stream,
+                                     prefix="valid",
+                                     after_epoch=True)
+
+plot = Plot(document='CTC_dummy_dataset_%d_%d'%(h_dim, rec_dim),
+            channels=[['train_CTC', 'valid_CTC'],
+                      ['train_error_rate', 'valid_error_rate'],
+                      ['valid_edit_distance'],
+                      ['valid_errors_per_char']],
+            after_epoch=True)
+
+model = Model(cost)
+main_loop = MainLoop(data_stream=stream, algorithm=algorithm,
+                     extensions=[monitor_cost, monitor_valid, plot,
+                                 FinishAfter(after_n_epochs=n_epochs),
+                                 Printing()],
+                     model=model)
+
+print('Starting training ...')  # ---------------------------------------------------
+main_loop.run()
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et:
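The cost above comes from `CTC().apply_log_domain`, i.e. the CTC forward recursion evaluated in log space so that the products of many small path probabilities do not underflow (the "log domain" the commit messages refer to). The repo's `ctc.py` is not part of this diff, so as a rough sketch of the technique only: below is a minimal, unbatched log-domain CTC forward pass in plain numpy. The names `log_add` and `ctc_log_likelihood` are illustrative, not the repo's API, and the blank is assumed to be the extra class appended by `rec_to_o` (index `num_output_classes`).

```python
# Minimal log-domain CTC forward pass for ONE sequence (no batching,
# no masks). Illustrative sketch only -- not the repo's ctc.py.
import numpy as np

def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -np.inf:
        return b
    if b == -np.inf:
        return a
    m = max(a, b)
    return m + np.log1p(np.exp(-abs(a - b)))

def ctc_log_likelihood(log_probs, label, blank):
    """log P(label | frame-wise log_probs), with log_probs of shape T x (C+1)."""
    T = log_probs.shape[0]
    # Interleave blanks: (l1, l2) -> (blank, l1, blank, l2, blank)
    ext = [blank]
    for l in label:
        ext += [l, blank]
    S = len(ext)

    # alpha[s] = log prob of all alignments of the first t frames
    # that end at extended-label position s
    alpha = np.full(S, -np.inf)
    alpha[0] = log_probs[0, ext[0]]          # start on the blank...
    if S > 1:
        alpha[1] = log_probs[0, ext[1]]      # ...or on the first label

    for t in range(1, T):
        prev = alpha.copy()
        for s in range(S):
            a = prev[s]                      # repeat the current symbol
            if s >= 1:
                a = log_add(a, prev[s - 1])  # advance one position
            # Skip over a blank, allowed only between two distinct labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a = log_add(a, prev[s - 2])
            alpha[s] = a + log_probs[t, ext[s]]

    # Valid alignments end on the last label or the trailing blank
    return log_add(alpha[S - 1], alpha[S - 2]) if S > 1 else alpha[S - 1]

# Sanity check: uniform predictions over 2 labels + blank (index 2)
logp = np.log(np.full((4, 3), 1.0 / 3.0))
print(ctc_log_likelihood(logp, [0, 1], blank=2))
```

The batched Theano version trained here additionally has to handle `y_hat_mask`/`y_len` and vectorize the recursion over the batch axis; those details are deliberately left out of the sketch.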