# about | summary | refs | log | blame | commit | diff
# path: root/mohammad/test_ctc.py
# blob: a24d6344bb13ce07fdc05592511b3d0b27f5158c (plain) (tree)






































































































































                                                                         
import theano
import numpy
from theano import tensor
from blocks.model import Model
from blocks.bricks import Linear, Tanh
from ctc_cost import CTC
from blocks.initialization import IsotropicGaussian, Constant
from fuel.datasets import IterableDataset
from fuel.streams import DataStream
from blocks.algorithms import (GradientDescent, Scale,
                               StepClipping, CompositeRule)
from blocks.extensions.monitoring import TrainingDataMonitoring
from blocks.main_loop import MainLoop
from blocks.extensions import FinishAfter, Printing
from blocks.bricks.recurrent import SimpleRecurrent
from blocks.graph import ComputationGraph
# Prefer the C-accelerated ``cPickle`` where it exists (Python 2) and fall
# back to the standard ``pickle`` module otherwise.  Only a failed import
# triggers the fallback; any other error is allowed to propagate.
try:
    import cPickle as pickle
except ImportError:
    import pickle

# Float dtype taken from the Theano configuration; used below as the dtype of
# every symbolic input (x, x_mask, y, y_mask).
floatX = theano.config.floatX


@theano.compile.ops.as_op(itypes=[tensor.lvector],
                          otypes=[tensor.lvector])
def print_pred(y_hat):
    """Collapse a frame-wise class path into a label sequence.

    The function is wrapped with ``as_op`` so that it can be used as a node
    of a Theano graph taking and returning an ``lvector``.

    Parameters
    ----------
    y_hat : 1-D integer array
        One class index per frame.

    Returns
    -------
    numpy.ndarray
        int64 vector containing ``y_hat`` with consecutive repeats merged
        and blank symbols dropped.  It is empty (but still int64) when every
        frame is blank.
    """
    # Index treated as the blank symbol.
    # NOTE(review): presumably the last of the ``num_classes + 1`` softmax
    # outputs -- confirm against the convention used by ``ctc_cost``.
    blank_symbol = 4
    res = []
    for i, s in enumerate(y_hat):
        # Keep a frame only if it is not blank and differs from the frame
        # right before it (i.e. it starts a new run of that symbol).
        if (s != blank_symbol) and (i == 0 or s != y_hat[i - 1]):
            res.append(s)
    # Force int64: ``numpy.asarray([])`` would be float64, which does not
    # match the declared ``lvector`` output when no label is emitted.
    return numpy.asarray(res, dtype='int64')

# Layer sizes used when the bricks are built.
x_dim = 4         # input_dim of the 'x_to_h' projection
h_dim = 9         # size of the recurrent layer
num_classes = 4   # the output projection has num_classes + 1 units

# Number of epochs after which FinishAfter stops the main loop.
n_epochs = 200

# NOTE: unpickling can execute code stored in the file; only use a
# ctc_test_data.pkl that comes from a trusted source.
with open("ctc_test_data.pkl", "rb") as pkl_file:
    try:
        data = pickle.load(pkl_file)
    except UnicodeDecodeError:
        # Typical failure when a pickle written by Python 2 is read on
        # Python 3.  The failed attempt may already have consumed part of
        # the stream, so rewind before retrying with byte strings.
        pkl_file.seek(0)
        data = pickle.load(pkl_file, encoding='bytes')

# The keys are bytes after the ``encoding='bytes'`` fallback and str
# otherwise; accept either spelling.
inputs, labels, raw_inputs_mask, labels_mask = (
    data[field] if field in data else data[field.encode('ascii')]
    for field in ('inputs', 'labels', 'mask_inputs', 'mask_labels'))
# from S x T x B x D to S x T x B
inputs_mask = numpy.max(raw_inputs_mask, axis=-1)



print('Building model ...')

# Symbolic inputs, all of dtype floatX.
x = tensor.tensor3('x', dtype=floatX)            # T x B x F
x_mask = tensor.matrix('x_mask', dtype=floatX)   # T x B
y = tensor.matrix('y', dtype=floatX)             # L x B
y_mask = tensor.matrix('y_mask', dtype=floatX)   # L x B

# Input projection: x_dim -> h_dim.
x_to_h = Linear(input_dim=x_dim, output_dim=h_dim, name='x_to_h')
x_transform = x_to_h.apply(x)

# Tanh recurrence over the projected inputs.
rnn = SimpleRecurrent(dim=h_dim, activation=Tanh(), name="rnn")
h = rnn.apply(x_transform)

# Output projection: h_dim -> num_classes + 1.
h_to_o = Linear(input_dim=h_dim, output_dim=num_classes + 1, name='h_to_o')
h_transform = h_to_o.apply(h)

# The softmax is applied to a 2-D view with num_classes + 1 columns and the
# result is folded back to T x B x C+1 using the leading shape of h.
y_hat = tensor.nnet.softmax(h_transform.reshape((-1, num_classes + 1)))
y_hat = y_hat.reshape((h.shape[0], h.shape[1], -1))
y_hat.name = 'y_hat'

# The mask handed to CTC for the outputs is the input mask itself.
y_hat_mask = x_mask
cost = CTC().apply(y, y_hat, y_mask, y_hat_mask, 'normal_scale')
cost.name = 'CTC'

# Initialization: IsotropicGaussian(0.01) weights and constant-zero biases,
# applied brick by brick in this fixed order.
for brick in (rnn, x_to_h, h_to_o):
    brick.weights_init = IsotropicGaussian(0.01)
    brick.biases_init = Constant(0)
    brick.initialize()

print('Bulding DataStream ...')
# The source names are the same as the names given to the symbolic
# variables x, x_mask, y and y_mask.
dataset = IterableDataset({
    'x': inputs,
    'x_mask': inputs_mask,
    'y': labels,
    'y_mask': labels_mask})
stream = DataStream(dataset)

print('Bulding training process...')
# Gradient descent on the CTC cost over every parameter found in its graph;
# each step is first clipped by StepClipping(10.0), then scaled by 0.02.
algorithm = GradientDescent(
    cost=cost,
    parameters=ComputationGraph(cost).parameters,
    step_rule=CompositeRule([StepClipping(10.0), Scale(0.02)]))

# Training cost, reported after every epoch.
monitor_cost = TrainingDataMonitoring(
    [cost], prefix="train", after_epoch=True)

# Index along the batch axis of the single example whose prediction and
# target are monitored.
sample = 8

# Per-frame argmax for that example, collapsed by print_pred.
y_hat_max_path = print_pred(y_hat[:, sample, :].argmax(axis=1))
y_hat_max_path.name = 'Viterbi'
monitor_output = TrainingDataMonitoring(
    [y_hat_max_path], prefix="y_hat", every_n_epochs=1)

# Target of the same example, cut to the length given by the sum of its mask.
length = y_mask[:, sample].sum().astype('int32')
tar = y[:length, sample].astype('int32')
tar.name = '_Target_Seq'
monitor_target = TrainingDataMonitoring(
    [tar], prefix="y", every_n_epochs=1)

model = Model(cost)
# Extensions are listed in their original order: the three monitors, the
# stop condition after n_epochs epochs, then the log printer.
main_loop = MainLoop(
    model=model,
    data_stream=stream,
    algorithm=algorithm,
    extensions=[
        monitor_cost,
        monitor_output,
        monitor_target,
        FinishAfter(after_n_epochs=n_epochs),
        Printing()])

print('Starting training ...')
main_loop.run()