about | summary | refs | log | tree | commit | diff
path: root/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'main.py')
-rw-r--r-- main.py | 139
1 files changed, 58 insertions, 81 deletions
diff --git a/main.py b/main.py
index b71d339..e288b9b 100644
--- a/main.py
+++ b/main.py
@@ -3,143 +3,120 @@ import numpy
from theano import tensor
from blocks.model import Model
from blocks.bricks import Linear, Tanh
+from blocks.bricks.lookup import LookupTable
from ctc import CTC
from blocks.initialization import IsotropicGaussian, Constant
from fuel.datasets import IterableDataset
from fuel.streams import DataStream
from blocks.algorithms import (GradientDescent, Scale,
StepClipping, CompositeRule)
-from blocks.extensions.monitoring import TrainingDataMonitoring
+from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring
from blocks.main_loop import MainLoop
from blocks.extensions import FinishAfter, Printing
-from blocks.bricks.recurrent import SimpleRecurrent
+from blocks.bricks.recurrent import SimpleRecurrent, LSTM
from blocks.graph import ComputationGraph
-try:
- import cPickle as pickle
-except:
- import pickle
-floatX = theano.config.floatX
+from dummy_dataset import setup_datastream
+floatX = theano.config.floatX
-@theano.compile.ops.as_op(itypes=[tensor.lvector],
- otypes=[tensor.lvector])
-def print_pred(y_hat):
- blank_symbol = 4
- res = []
- for i, s in enumerate(y_hat):
- if (s != blank_symbol) and (i == 0 or s != y_hat[i - 1]):
- res += [s]
- return numpy.asarray(res)
n_epochs = 200
-x_dim = 4
-h_dim = 9
-num_classes = 4
-
-with open("ctc_test_data.pkl", "rb") as pkl_file:
- try:
- data = pickle.load(pkl_file)
- inputs = data['inputs']
- labels = data['labels']
- # from S x T x B x D to S x T x B
- inputs_mask = numpy.max(data['mask_inputs'], axis=-1)
- labels_mask = data['mask_labels']
- except:
- data = pickle.load(pkl_file, encoding='bytes')
- inputs = data[b'inputs']
- labels = data[b'labels']
- # from S x T x B x D to S x T x B
- inputs_mask = numpy.max(data[b'mask_inputs'], axis=-1)
- labels_mask = data[b'mask_labels']
-
+num_input_classes = 5
+h_dim = 20
+rec_dim = 20
+num_output_classes = 4
print('Building model ...')
-# x : T x B x F
-x = tensor.tensor3('x', dtype=floatX)
-# x_mask : T x B
-x_mask = tensor.matrix('x_mask', dtype=floatX)
+inputt = tensor.lmatrix('input').T
+input_mask = tensor.matrix('input_mask').T
+y = tensor.lmatrix('output').T
+y_mask = tensor.matrix('output_mask').T
+# inputt : T x B
+# input_mask : T x B
# y : L x B
-y = tensor.lmatrix('y')
# y_mask : L x B
-y_mask = tensor.matrix('y_mask', dtype=floatX)
# Linear bricks in
-x_to_h = Linear(name='x_to_h',
- input_dim=x_dim,
- output_dim=h_dim)
-x_transform = x_to_h.apply(x)
+input_to_h = LookupTable(num_input_classes, h_dim, name='lookup')
+h = input_to_h.apply(inputt)
+# h : T x B x h_dim
# RNN bricks
-rnn = SimpleRecurrent(activation=Tanh(),
- dim=h_dim, name="rnn")
-h = rnn.apply(x_transform)
+pre_lstm = Linear(input_dim=h_dim, output_dim=4*rec_dim, name='LSTM_linear')
+lstm = LSTM(activation=Tanh(),
+ dim=rec_dim, name="rnn")
+rnn_out, _ = lstm.apply(pre_lstm.apply(h))
# Linear bricks out
-h_to_o = Linear(name='h_to_o',
- input_dim=h_dim,
- output_dim=num_classes + 1)
-h_transform = h_to_o.apply(h)
+rec_to_o = Linear(name='rec_to_o',
+ input_dim=rec_dim,
+ output_dim=num_output_classes + 1)
+y_hat_pre = rec_to_o.apply(rnn_out)
+# y_hat_pre : T x B x C+1
# y_hat : T x B x C+1
y_hat = tensor.nnet.softmax(
- h_transform.reshape((-1, num_classes + 1))
-).reshape((h.shape[0], h.shape[1], -1))
+ y_hat_pre.reshape((-1, num_output_classes + 1))
+).reshape((y_hat_pre.shape[0], y_hat_pre.shape[1], -1))
y_hat.name = 'y_hat'
-y_hat_mask = x_mask
+y_hat_mask = input_mask
# Cost
-cost = CTC().apply(y, y_hat, y_mask.sum(axis=1), y_hat_mask).mean()
+y_len = y_mask.sum(axis=0)
+cost = CTC().apply(y, y_hat, y_len, y_hat_mask).mean()
cost.name = 'CTC'
+dl, dl_length = CTC().best_path_decoding(y_hat, y_hat_mask)
+L = y.shape[0]
+B = y.shape[1]
+dl = dl[:L, :]
+is_error = tensor.neq(dl, y) * tensor.lt(tensor.arange(L)[:,None], y_len[None,:])
+is_error = tensor.switch(is_error.sum(axis=0), tensor.ones((B,)), tensor.neq(y_len, dl_length))
+
+error_rate = is_error.mean()
+error_rate.name = 'error_rate'
+
# Initialization
-for brick in (rnn, x_to_h, h_to_o):
+for brick in [input_to_h, pre_lstm, lstm, rec_to_o]:
brick.weights_init = IsotropicGaussian(0.01)
brick.biases_init = Constant(0)
brick.initialize()
print('Bulding DataStream ...')
-dataset = IterableDataset({'x': inputs,
- 'x_mask': inputs_mask,
- 'y': labels,
- 'y_mask': labels_mask})
-stream = DataStream(dataset)
+ds, stream = setup_datastream(batch_size=10,
+ nb_examples=1000, rng_seed=123,
+ min_out_len=10, max_out_len=20)
+valid_ds, valid_stream = setup_datastream(batch_size=10,
+ nb_examples=1000, rng_seed=456,
+ min_out_len=10, max_out_len=20)
print('Bulding training process...')
algorithm = GradientDescent(cost=cost,
parameters=ComputationGraph(cost).parameters,
step_rule=CompositeRule([StepClipping(10.0),
Scale(0.02)]))
-monitor_cost = TrainingDataMonitoring([cost],
+monitor_cost = TrainingDataMonitoring([cost, error_rate],
prefix="train",
after_epoch=True)
-# sample number to monitor
-sample = 8
-
-y_hat_max_path = print_pred(tensor.argmax(y_hat[:, sample, :], axis=1))
-y_hat_max_path.name = 'Viterbi'
-monitor_output = TrainingDataMonitoring([y_hat_max_path],
- prefix="y_hat",
- every_n_epochs=1)
-
-length = tensor.sum(y_mask[:, sample]).astype('int32')
-tar = y[:length, sample].astype('int32')
-tar.name = '_Target_Seq'
-monitor_target = TrainingDataMonitoring([tar],
- prefix="y",
- every_n_epochs=1)
+monitor_valid = DataStreamMonitoring([cost, error_rate],
+ data_stream=valid_stream,
+ prefix="valid",
+ after_epoch=True)
model = Model(cost)
main_loop = MainLoop(data_stream=stream, algorithm=algorithm,
- extensions=[monitor_cost, monitor_output,
- monitor_target,
+ extensions=[monitor_cost, monitor_valid,
FinishAfter(after_n_epochs=n_epochs),
Printing()],
model=model)
print('Starting training ...')
main_loop.run()
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :