author    Thomas Mesnard <thomas.mesnard@ens.fr>    2015-12-23 11:14:57 +0100
committer Thomas Mesnard <thomas.mesnard@ens.fr>    2015-12-23 11:14:57 +0100
commit    08b63add743087ac0e2bbb7f739605642b3edc7b (patch)
tree      b253731214ab1c12fd7c8579fbdaad6aa99feebf
parent    694964422eef7b835c1bfa3643fdee3bc1cffdd7 (diff)
Add main.py
-rw-r--r--  ctc.py              8
l---------  ctc_test_data.pkl   1
-rw-r--r--  main.py             145
3 files changed, 150 insertions(+), 4 deletions(-)
diff --git a/ctc.py b/ctc.py
index c546cb8..57b2d36 100644
--- a/ctc.py
+++ b/ctc.py
@@ -7,7 +7,7 @@ from blocks.bricks import Brick
# L: OUTPUT_SEQUENCE_LENGTH
# C: NUM_CLASSES
class CTC(Brick):
- def apply(l, probs, l_len=None, probs_mask=None):
+ def apply(self, l, probs, l_len=None, probs_mask=None):
"""
Numeration:
Characters 0 to C-1 are true characters
@@ -63,7 +63,7 @@ class CTC(Brick):
alphabar = prev_alpha + prev_alpha1
alphabar = tensor.switch(l_case2, alphabar + prev_alpha2, alphabar)
next_alpha = alphabar * p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S))
- next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha]
+ next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha)
next_c = next_alpha.sum(axis=1)
return next_alpha / next_c[:, None], next_c
@@ -77,7 +77,7 @@ class CTC(Brick):
return tensor.log(c).sum(axis=0)
- def best_path_decoding(probs, probs_mask=None):
+ def best_path_decoding(self, probs, probs_mask=None):
# probs is T x B x C+1
T = probs.shape[0]
B = probs.shape[1]
@@ -106,7 +106,7 @@ class CTC(Brick):
return label[-1], label_length[-1]
- def prefix_search(probs, probs_mask=None):
+ def prefix_search(self, probs, probs_mask=None):
# Hard one...
pass
diff --git a/ctc_test_data.pkl b/ctc_test_data.pkl
new file mode 120000
index 0000000..cd78721
--- /dev/null
+++ b/ctc_test_data.pkl
@@ -0,0 +1 @@
+mohammad/ctc_test_data.pkl
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..d384edb
--- /dev/null
+++ b/main.py
@@ -0,0 +1,145 @@
+import theano
+import numpy
+from theano import tensor
+from blocks.model import Model
+from blocks.bricks import Linear, Tanh
+from ctc import CTC
+from blocks.initialization import IsotropicGaussian, Constant
+from fuel.datasets import IterableDataset
+from fuel.streams import DataStream
+from blocks.algorithms import (GradientDescent, Scale,
+ StepClipping, CompositeRule)
+from blocks.extensions.monitoring import TrainingDataMonitoring
+from blocks.main_loop import MainLoop
+from blocks.extensions import FinishAfter, Printing
+from blocks.bricks.recurrent import SimpleRecurrent
+from blocks.graph import ComputationGraph
+try:
+ import cPickle as pickle
+except ImportError:
+ import pickle
+
+floatX = theano.config.floatX
+
+
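+# Post-process an argmax path for display: drop the blank symbol and collapse
+# consecutive repeats (standard CTC best-path decoding).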
+@theano.compile.ops.as_op(itypes=[tensor.lvector],
+ otypes=[tensor.lvector])
+def print_pred(y_hat):
+ blank_symbol = 4
+ res = []
+ for i, s in enumerate(y_hat):
+ if (s != blank_symbol) and (i == 0 or s != y_hat[i - 1]):
+ res += [s]
+ return numpy.asarray(res)
+
+n_epochs = 200
+x_dim = 4
+h_dim = 9
+num_classes = 4
+
+with open("ctc_test_data.pkl", "rb") as pkl_file:
+ try:
+ data = pickle.load(pkl_file)
+ inputs = data['inputs']
+ labels = data['labels']
+ # from S x T x B x D to S x T x B
+ inputs_mask = numpy.max(data['mask_inputs'], axis=-1)
+ labels_mask = data['mask_labels']
+ except:
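+        # The pickle was written under Python 2; reload it with byte-string keys.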
+ data = pickle.load(pkl_file, encoding='bytes')
+ inputs = data[b'inputs']
+ labels = data[b'labels']
+ # from S x T x B x D to S x T x B
+ inputs_mask = numpy.max(data[b'mask_inputs'], axis=-1)
+ labels_mask = data[b'mask_labels']
+
+
+
+print('Building model ...')
+
+# x : T x B x F
+x = tensor.tensor3('x', dtype=floatX)
+# x_mask : T x B
+x_mask = tensor.matrix('x_mask', dtype=floatX)
+# y : L x B
+y = tensor.lmatrix('y')
+# y_mask : L x B
+y_mask = tensor.matrix('y_mask', dtype=floatX)
+
+# Linear bricks in
+x_to_h = Linear(name='x_to_h',
+ input_dim=x_dim,
+ output_dim=h_dim)
+x_transform = x_to_h.apply(x)
+
+# RNN bricks
+rnn = SimpleRecurrent(activation=Tanh(),
+ dim=h_dim, name="rnn")
+h = rnn.apply(x_transform)
+
+# Linear bricks out
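+# (num_classes + 1 outputs: the extra unit is the CTC blank symbol)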
+h_to_o = Linear(name='h_to_o',
+ input_dim=h_dim,
+ output_dim=num_classes + 1)
+h_transform = h_to_o.apply(h)
+
+# y_hat : T x B x C+1
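+# softmax operates on a 2-D matrix, so flatten to (T*B, C+1) and reshape back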
+y_hat = tensor.nnet.softmax(
+ h_transform.reshape((-1, num_classes + 1))
+).reshape((h.shape[0], h.shape[1], -1))
+y_hat.name = 'y_hat'
+
+y_hat_mask = x_mask
+
+# Cost
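+# arguments map to CTC.apply(l, probs, l_len, probs_mask) defined in ctc.py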
+cost = CTC().apply(y, y_hat, y_mask, y_hat_mask)
+cost.name = 'CTC'
+
+# Initialization
+for brick in (rnn, x_to_h, h_to_o):
+ brick.weights_init = IsotropicGaussian(0.01)
+ brick.biases_init = Constant(0)
+ brick.initialize()
+
+print('Building DataStream ...')
+dataset = IterableDataset({'x': inputs,
+ 'x_mask': inputs_mask,
+ 'y': labels,
+ 'y_mask': labels_mask})
+stream = DataStream(dataset)
+
+print('Building training process...')
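+# Steps are clipped to an L2 norm of 10.0 before the fixed 0.02 learning-rate scaling.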
+algorithm = GradientDescent(cost=cost,
+ parameters=ComputationGraph(cost).parameters,
+ step_rule=CompositeRule([StepClipping(10.0),
+ Scale(0.02)]))
+monitor_cost = TrainingDataMonitoring([cost],
+ prefix="train",
+ after_epoch=True)
+
+# sample number to monitor
+sample = 8
+
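+# Best-path decoding of the monitored sample: per-frame argmax, collapsed by print_pred.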
+y_hat_max_path = print_pred(tensor.argmax(y_hat[:, sample, :], axis=1))
+y_hat_max_path.name = 'Viterbi'
+monitor_output = TrainingDataMonitoring([y_hat_max_path],
+ prefix="y_hat",
+ every_n_epochs=1)
+
+length = tensor.sum(y_mask[:, sample]).astype('int32')
+tar = y[:length, sample].astype('int32')
+tar.name = '_Target_Seq'
+monitor_target = TrainingDataMonitoring([tar],
+ prefix="y",
+ every_n_epochs=1)
+
+model = Model(cost)
+main_loop = MainLoop(data_stream=stream, algorithm=algorithm,
+ extensions=[monitor_cost, monitor_output,
+ monitor_target,
+ FinishAfter(after_n_epochs=n_epochs),
+ Printing()],
+ model=model)
+
+print('Starting training ...')
+main_loop.run()