aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ctc.py16
-rw-r--r--dummy_dataset.py85
-rw-r--r--main.py139
3 files changed, 152 insertions, 88 deletions
diff --git a/ctc.py b/ctc.py
index f03313b..5f34b2c 100644
--- a/ctc.py
+++ b/ctc.py
@@ -1,5 +1,6 @@
import numpy
+import theano
from theano import tensor, scan
from blocks.bricks import Brick
@@ -65,6 +66,7 @@ class CTC(Brick):
alpha_bar = tensor.switch(l_case2, alpha_bar + prev_alpha_2, alpha_bar)
next_alpha = alpha_bar * p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S))
next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha)
+ next_alpha = next_alpha * tensor.lt(tensor.arange(S)[None,:], (2*l_len+1)[:, None])
next_c = next_alpha.sum(axis=1)
return next_alpha / next_c[:, None], next_c
@@ -85,15 +87,15 @@ class CTC(Brick):
C = probs.shape[2]-1
maxprob = probs.argmax(axis=2)
+ is_double = tensor.eq(maxprob[:-1], maxprob[1:])
+ maxprob = tensor.switch(tensor.concatenate([tensor.zeros((1,B)), is_double]),
+ maxprob, C*tensor.ones_like(maxprob))
# returns two values :
# label : (T x) T x B
# label_length : (T x) B
def recursion(maxp, p_mask, label_length, label):
- label_length = label_length[-1]
- label = label[-1]
-
- nonzero = p_mask * tensor.ne(maxp, C)
+ nonzero = p_mask * tensor.neq(maxp, C)
nonzero_id = nonzero.nonzero()[0]
new_label = tensor.set_subtensor(label[label_length[nonzero_id], nonzero_id], maxp[nonzero_id])
@@ -101,9 +103,9 @@ class CTC(Brick):
return new_label_length, new_label
- label_length, label = tensor.scan(fn=recursion,
- sequences=[maxprob, probs_mask],
- outputs_info=[tensor.zeros((B),dtype='int32'),tensor.zeros((T,B))])
+ [label_length, label], _ = scan(fn=recursion,
+ sequences=[maxprob, probs_mask],
+ outputs_info=[tensor.zeros((B,),dtype='int32'),tensor.zeros((T,B))])
return label[-1], label_length[-1]
diff --git a/dummy_dataset.py b/dummy_dataset.py
new file mode 100644
index 0000000..a65a946
--- /dev/null
+++ b/dummy_dataset.py
@@ -0,0 +1,85 @@
+
+import logging
+import random
+import numpy
+
+import cPickle
+
+from picklable_itertools import iter_
+
+from fuel.datasets import Dataset
+from fuel.streams import DataStream
+from fuel.schemes import IterationScheme, ConstantScheme, SequentialExampleScheme
+from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer
+
+import sys
+import os
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+# Synthetic in-memory dataset of (input, output) symbol-sequence pairs,
+# generated once in __init__ and stored in self.data.
+# NOTE(review): this diff view has lost the original indentation, so any
+# statement about loop/branch nesting below is a best reading of the code
+# and should be confirmed against the real file.
+class DummyDataset(Dataset):
+ # nb_examples: number of (input, output) pairs to generate.
+ # rng_seed: value passed to random.seed(); this reseeds the module-level
+ # `random` generator, not a private one.
+ # min_out_len, max_out_len: arguments of random.randrange(), so the drawn
+ # output length l satisfies min_out_len <= l < max_out_len.
+ # **kwargs: forwarded unchanged to Dataset.__init__.
+ def __init__(self, nb_examples, rng_seed, min_out_len, max_out_len, **kwargs):
+ self.provides_sources = ('input', 'output')
+
+ random.seed(rng_seed)
+
+ # One row per output label (row index 0..3); the five values of the
+ # selected row are the candidate input symbols for that label.
+ table = [
+ [0, 1, 2, 3, 4],
+ [0, 1, 2, 1, 0],
+ [4, 3, 2, 3, 4],
+ [4, 3, 2, 1, 0]
+ ]
+ # prob0: threshold of the draw that decides whether a row symbol is
+ # appended to the input.
+ prob0 = 0.7
+ # prob: threshold of the repeated draw that appends the same symbol again.
+ prob = 0.2
+
+ self.data = []
+ for n in range(nb_examples):
+ # o collects the output labels, i the input symbols of this example.
+ o = []
+ i = []
+ l = random.randrange(min_out_len, max_out_len)
+ for p in range(l):
+ # Draw a label, then walk the table row it selects.
+ o.append(random.randrange(len(table)))
+ for x in table[o[-1]]:
+ # q is assigned here but never read in the visible code.
+ q = 0
+ if random.uniform(0, 1) < prob0:
+ i.append(x)
+ # NOTE(review): unclear from this view whether the loop below is
+ # nested inside the `if` above or runs for every x -- confirm.
+ while random.uniform(0, 1) < prob:
+ i.append(x)
+ # Each stored example is the tuple (input symbols, output labels).
+ self.data.append((i, o))
+
+ super(DummyDataset, self).__init__(**kwargs)
+
+
+ # Returns self.data[request]; `state` is accepted but not used.
+ # Raises ValueError when request is None.
+ def get_data(self, state=None, request=None):
+ if request is None:
+ raise ValueError("Request required")
+
+ return self.data[request]
+
+# -------------- DATASTREAM SETUP --------------------
+
+def setup_datastream(batch_size, **kwargs):
+    """Build a DummyDataset and a batched, padded stream over it.
+
+    ``kwargs`` are passed to ``DummyDataset``; ``kwargs['nb_examples']`` is
+    also used as the argument of ``SequentialExampleScheme``. Examples are
+    grouped with ``ConstantScheme(batch_size)`` and then wrapped in
+    ``Padding`` with masks requested for the 'input' and 'output' sources.
+
+    Returns the tuple ``(dataset, stream)``.
+    """
+    dataset = DummyDataset(**kwargs)
+
+    example_scheme = SequentialExampleScheme(kwargs['nb_examples'])
+    examples = DataStream(dataset, iteration_scheme=example_scheme)
+
+    batch_scheme = ConstantScheme(batch_size)
+    batches = Batch(examples, iteration_scheme=batch_scheme)
+
+    padded = Padding(batches, mask_sources=['input', 'output'])
+    return dataset, padded
+
+if __name__ == "__main__":
+    # Smoke test: build a tiny stream and print its first few batches.
+    # batch_size is a required positional parameter of setup_datastream();
+    # calling without it raises TypeError, so it is passed explicitly.
+    ds, stream = setup_datastream(batch_size=2,
+                                  nb_examples=5,
+                                  rng_seed=123,
+                                  min_out_len=3,
+                                  max_out_len=6)
+
+    for i, d in enumerate(stream.get_epoch_iterator()):
+        # Single-argument print(...) prints the same text as the Python 2
+        # print statement.
+        print('--')
+        print(d)
+
+        # Stop once four batches (i = 0..3) have been printed.
+        if i > 2:
+            break
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/main.py b/main.py
index b71d339..e288b9b 100644
--- a/main.py
+++ b/main.py
@@ -3,143 +3,120 @@ import numpy
from theano import tensor
from blocks.model import Model
from blocks.bricks import Linear, Tanh
+from blocks.bricks.lookup import LookupTable
from ctc import CTC
from blocks.initialization import IsotropicGaussian, Constant
from fuel.datasets import IterableDataset
from fuel.streams import DataStream
from blocks.algorithms import (GradientDescent, Scale,
StepClipping, CompositeRule)
-from blocks.extensions.monitoring import TrainingDataMonitoring
+from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring
from blocks.main_loop import MainLoop
from blocks.extensions import FinishAfter, Printing
-from blocks.bricks.recurrent import SimpleRecurrent
+from blocks.bricks.recurrent import SimpleRecurrent, LSTM
from blocks.graph import ComputationGraph
-try:
- import cPickle as pickle
-except:
- import pickle
-floatX = theano.config.floatX
+from dummy_dataset import setup_datastream
+floatX = theano.config.floatX
-@theano.compile.ops.as_op(itypes=[tensor.lvector],
- otypes=[tensor.lvector])
-def print_pred(y_hat):
- blank_symbol = 4
- res = []
- for i, s in enumerate(y_hat):
- if (s != blank_symbol) and (i == 0 or s != y_hat[i - 1]):
- res += [s]
- return numpy.asarray(res)
n_epochs = 200
-x_dim = 4
-h_dim = 9
-num_classes = 4
-
-with open("ctc_test_data.pkl", "rb") as pkl_file:
- try:
- data = pickle.load(pkl_file)
- inputs = data['inputs']
- labels = data['labels']
- # from S x T x B x D to S x T x B
- inputs_mask = numpy.max(data['mask_inputs'], axis=-1)
- labels_mask = data['mask_labels']
- except:
- data = pickle.load(pkl_file, encoding='bytes')
- inputs = data[b'inputs']
- labels = data[b'labels']
- # from S x T x B x D to S x T x B
- inputs_mask = numpy.max(data[b'mask_inputs'], axis=-1)
- labels_mask = data[b'mask_labels']
-
+num_input_classes = 5
+h_dim = 20
+rec_dim = 20
+num_output_classes = 4
print('Building model ...')
-# x : T x B x F
-x = tensor.tensor3('x', dtype=floatX)
-# x_mask : T x B
-x_mask = tensor.matrix('x_mask', dtype=floatX)
+inputt = tensor.lmatrix('input').T
+input_mask = tensor.matrix('input_mask').T
+y = tensor.lmatrix('output').T
+y_mask = tensor.matrix('output_mask').T
+# inputt : T x B
+# input_mask : T x B
# y : L x B
-y = tensor.lmatrix('y')
# y_mask : L x B
-y_mask = tensor.matrix('y_mask', dtype=floatX)
# Linear bricks in
-x_to_h = Linear(name='x_to_h',
- input_dim=x_dim,
- output_dim=h_dim)
-x_transform = x_to_h.apply(x)
+input_to_h = LookupTable(num_input_classes, h_dim, name='lookup')
+h = input_to_h.apply(inputt)
+# h : T x B x h_dim
# RNN bricks
-rnn = SimpleRecurrent(activation=Tanh(),
- dim=h_dim, name="rnn")
-h = rnn.apply(x_transform)
+pre_lstm = Linear(input_dim=h_dim, output_dim=4*rec_dim, name='LSTM_linear')
+lstm = LSTM(activation=Tanh(),
+ dim=rec_dim, name="rnn")
+rnn_out, _ = lstm.apply(pre_lstm.apply(h))
# Linear bricks out
-h_to_o = Linear(name='h_to_o',
- input_dim=h_dim,
- output_dim=num_classes + 1)
-h_transform = h_to_o.apply(h)
+rec_to_o = Linear(name='rec_to_o',
+ input_dim=rec_dim,
+ output_dim=num_output_classes + 1)
+y_hat_pre = rec_to_o.apply(rnn_out)
+# y_hat_pre : T x B x C+1
# y_hat : T x B x C+1
y_hat = tensor.nnet.softmax(
- h_transform.reshape((-1, num_classes + 1))
-).reshape((h.shape[0], h.shape[1], -1))
+ y_hat_pre.reshape((-1, num_output_classes + 1))
+).reshape((y_hat_pre.shape[0], y_hat_pre.shape[1], -1))
y_hat.name = 'y_hat'
-y_hat_mask = x_mask
+y_hat_mask = input_mask
# Cost
-cost = CTC().apply(y, y_hat, y_mask.sum(axis=1), y_hat_mask).mean()
+y_len = y_mask.sum(axis=0)
+cost = CTC().apply(y, y_hat, y_len, y_hat_mask).mean()
cost.name = 'CTC'
+dl, dl_length = CTC().best_path_decoding(y_hat, y_hat_mask)
+L = y.shape[0]
+B = y.shape[1]
+dl = dl[:L, :]
+is_error = tensor.neq(dl, y) * tensor.lt(tensor.arange(L)[:,None], y_len[None,:])
+is_error = tensor.switch(is_error.sum(axis=0), tensor.ones((B,)), tensor.neq(y_len, dl_length))
+
+error_rate = is_error.mean()
+error_rate.name = 'error_rate'
+
# Initialization
-for brick in (rnn, x_to_h, h_to_o):
+for brick in [input_to_h, pre_lstm, lstm, rec_to_o]:
brick.weights_init = IsotropicGaussian(0.01)
brick.biases_init = Constant(0)
brick.initialize()
print('Bulding DataStream ...')
-dataset = IterableDataset({'x': inputs,
- 'x_mask': inputs_mask,
- 'y': labels,
- 'y_mask': labels_mask})
-stream = DataStream(dataset)
+ds, stream = setup_datastream(batch_size=10,
+ nb_examples=1000, rng_seed=123,
+ min_out_len=10, max_out_len=20)
+valid_ds, valid_stream = setup_datastream(batch_size=10,
+ nb_examples=1000, rng_seed=456,
+ min_out_len=10, max_out_len=20)
print('Bulding training process...')
algorithm = GradientDescent(cost=cost,
parameters=ComputationGraph(cost).parameters,
step_rule=CompositeRule([StepClipping(10.0),
Scale(0.02)]))
-monitor_cost = TrainingDataMonitoring([cost],
+monitor_cost = TrainingDataMonitoring([cost, error_rate],
prefix="train",
after_epoch=True)
-# sample number to monitor
-sample = 8
-
-y_hat_max_path = print_pred(tensor.argmax(y_hat[:, sample, :], axis=1))
-y_hat_max_path.name = 'Viterbi'
-monitor_output = TrainingDataMonitoring([y_hat_max_path],
- prefix="y_hat",
- every_n_epochs=1)
-
-length = tensor.sum(y_mask[:, sample]).astype('int32')
-tar = y[:length, sample].astype('int32')
-tar.name = '_Target_Seq'
-monitor_target = TrainingDataMonitoring([tar],
- prefix="y",
- every_n_epochs=1)
+monitor_valid = DataStreamMonitoring([cost, error_rate],
+ data_stream=valid_stream,
+ prefix="valid",
+ after_epoch=True)
model = Model(cost)
main_loop = MainLoop(data_stream=stream, algorithm=algorithm,
- extensions=[monitor_cost, monitor_output,
- monitor_target,
+ extensions=[monitor_cost, monitor_valid,
FinishAfter(after_n_epochs=n_epochs),
Printing()],
model=model)
print('Starting training ...')
main_loop.run()
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :